diff --git a/access.go b/access.go index 152db9c..a11c8d3 100644 --- a/access.go +++ b/access.go @@ -1,6 +1,7 @@ package osin import ( + "context" "crypto/sha256" "encoding/base64" "errors" @@ -151,6 +152,7 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication + ctx := r.Context() auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil @@ -174,13 +176,13 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A } // must have a valid client - if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil { return nil } // must be a valid authorization code var err error - ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code) + ret.AuthorizeData, err = w.Storage.LoadAuthorize(ctx, ret.Code) if err != nil { s.setErrorAndLog(w, E_INVALID_GRANT, err, "auth_code_request=%s", "error loading authorize data") return nil @@ -283,6 +285,7 @@ func extraScopes(access_scopes, refresh_scopes string) bool { func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication + ctx := r.Context() auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil @@ -305,13 +308,13 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access } // must have a valid client - if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil { return nil } // must be a valid refresh code var err error - ret.AccessData, err = w.Storage.LoadRefresh(ret.Code) + ret.AccessData, err = w.Storage.LoadRefresh(ctx, ret.Code) if err != nil { s.setErrorAndLog(w, E_INVALID_GRANT, err, "refresh_token=%s", "error loading access data") return nil @@ -354,6 +357,7 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication + ctx := r.Context() auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil @@ -377,7 +381,7 @@ func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequ } // must have a valid client - if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil { return nil } @@ -389,6 +393,7 @@ func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication + ctx := r.Context() auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil @@ -404,7 +409,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A } // must have a valid client - if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil { return nil } @@ -416,6 +421,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest { // get client authentication + ctx := r.Context() auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams) if auth == nil { return nil @@ -439,7 +445,7 @@ func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessReq } // must have a valid client - if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil { + if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil { return nil } @@ -451,6 +457,7 @@ func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessReq func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessRequest) { // don't process if is already an error + ctx := r.Context() if w.IsError { return } @@ -487,22 +494,22 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq } // save access token - if err = w.Storage.SaveAccess(ret); err != nil { + if err = w.Storage.SaveAccess(ctx, ret); err != nil { s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token") return } // remove authorization token if ret.AuthorizeData != nil { - w.Storage.RemoveAuthorize(ret.AuthorizeData.Code) + w.Storage.RemoveAuthorize(ctx, ret.AuthorizeData.Code) } // remove previous access token if ret.AccessData != nil && !s.Config.RetainTokenAfterRefresh { if ret.AccessData.RefreshToken != "" { - w.Storage.RemoveRefresh(ret.AccessData.RefreshToken) + w.Storage.RemoveRefresh(ctx, ret.AccessData.RefreshToken) } - w.Storage.RemoveAccess(ret.AccessData.AccessToken) + w.Storage.RemoveAccess(ctx, ret.AccessData.AccessToken) } // output data @@ -524,8 +531,9 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq // getClient looks up and authenticates the basic auth using the given // storage. Sets an error on the response if auth fails or a server error occurs. -func (s Server) getClient(auth *BasicAuth, storage Storage, w *Response) Client { - client, err := storage.GetClient(auth.Username) +func (s Server) getClient(ctx context.Context, auth *BasicAuth, storage Storage, w *Response) Client { + + client, err := storage.GetClient(ctx, auth.Username) if err == ErrNotFound { s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found") return nil @@ -559,4 +567,4 @@ func (s Server) setErrorAndLog(w *Response, responseError string, internalError w.SetError(responseError, "") s.Logger.Printf(format, append([]interface{}{responseError, internalError}, debugArgs...)...) -} \ No newline at end of file +} diff --git a/access_test.go b/access_test.go index 504e39c..caaa189 100644 --- a/access_test.go +++ b/access_test.go @@ -1,6 +1,7 @@ package osin import ( + "context" "net/http" "net/url" "testing" @@ -79,7 +80,7 @@ func TestAccessRefreshToken(t *testing.T) { } //fmt.Printf("%+v", resp) - if _, err := server.Storage.LoadRefresh("r9999"); err == nil { + if _, err := server.Storage.LoadRefresh(req.Context(), "r9999"); err == nil { t.Fatalf("token was not deleted") } @@ -130,7 +131,7 @@ func TestAccessRefreshTokenSaveToken(t *testing.T) { } //fmt.Printf("%+v", resp) - if _, err := server.Storage.LoadRefresh("r9999"); err != nil { + if _, err := server.Storage.LoadRefresh(req.Context(), "r9999"); err != nil { t.Fatalf("token incorrectly deleted: %s", err.Error()) } @@ -293,7 +294,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { storage := &TestingStorage{clients: map[string]Client{myclient.Id: myclient}} sconfig := NewServerConfig() server := NewServer(sconfig, storage) - + ctx := context.Background() // Ensure bad secret fails { auth := &BasicAuth{ @@ -301,7 +302,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { Password: "invalidsecret", } w := &Response{} - client := server.getClient(auth, storage, w) + client := server.getClient(ctx, auth, storage, w) if client != nil { t.Errorf("Expected error, got client: %v", client) } @@ -322,7 +323,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { Password: "nonexistent", } w := &Response{} - client := server.getClient(auth, storage, w) + client := server.getClient(ctx, auth, storage, w) if client != nil { t.Errorf("Expected error, got client: %v", client) } @@ -343,7 +344,7 @@ func TestGetClientWithoutMatcher(t *testing.T) { Password: "myclientsecret", } w := &Response{} - client := server.getClient(auth, storage, w) + client := server.getClient(ctx, auth, storage, w) if client != myclient { t.Errorf("Expected client, got nil with response: %v", w) } @@ -374,7 +375,7 @@ func TestGetClientSecretMatcher(t *testing.T) { storage := &TestingStorage{clients: map[string]Client{myclient.Id: myclient}} sconfig := NewServerConfig() server := NewServer(sconfig, storage) - + ctx := context.Background() // Ensure bad secret fails, but does not panic (doesn't call GetSecret) { auth := &BasicAuth{ @@ -382,7 +383,7 @@ func TestGetClientSecretMatcher(t *testing.T) { Password: "invalidsecret", } w := &Response{} - client := server.getClient(auth, storage, w) + client := server.getClient(ctx, auth, storage, w) if client != nil { t.Errorf("Expected error, got client: %v", client) } @@ -395,7 +396,7 @@ func TestGetClientSecretMatcher(t *testing.T) { Password: "myclientsecret", } w := &Response{} - client := server.getClient(auth, storage, w) + client := server.getClient(ctx, auth, storage, w) if client != myclient { t.Errorf("Expected client, got nil with response: %v", w) } @@ -437,12 +438,13 @@ func TestAccessAuthorizationCodePKCE(t *testing.T) { } for k, test := range testcases { + ctx := context.Background() testStorage := NewTestingStorage() sconfig := NewServerConfig() sconfig.AllowedAccessTypes = AllowedAccessType{AUTHORIZATION_CODE} server := NewServer(sconfig, testStorage) server.AccessTokenGen = &TestingAccessTokenGen{} - server.Storage.SaveAuthorize(&AuthorizeData{ + server.Storage.SaveAuthorize(ctx, &AuthorizeData{ Client: testStorage.clients["public-client"], Code: "pkce-code", ExpiresIn: 3600, diff --git a/authorize.go b/authorize.go index f5c67d1..8d50b84 100644 --- a/authorize.go +++ b/authorize.go @@ -105,7 +105,7 @@ type AuthorizeTokenGen interface { // authorization requests func (s *Server) HandleAuthorizeRequest(w *Response, r *http.Request) *AuthorizeRequest { r.ParseForm() - + ctx := r.Context() // create the authorization request unescapedUri, err := url.QueryUnescape(r.Form.Get("redirect_uri")) if err != nil { @@ -123,7 +123,7 @@ func (s *Server) HandleAuthorizeRequest(w *Response, r *http.Request) *Authorize } // must have a valid client - ret.Client, err = w.Storage.GetClient(r.Form.Get("client_id")) + ret.Client, err = w.Storage.GetClient(ctx, r.Form.Get("client_id")) if err == ErrNotFound { w.SetErrorState(E_UNAUTHORIZED_CLIENT, "", ret.State) return nil @@ -205,6 +205,7 @@ func (s *Server) HandleAuthorizeRequest(w *Response, r *http.Request) *Authorize func (s *Server) FinishAuthorizeRequest(w *Response, r *http.Request, ar *AuthorizeRequest) { // don't process if is already an error + ctx := r.Context() if w.IsError { return } @@ -258,7 +259,7 @@ func (s *Server) FinishAuthorizeRequest(w *Response, r *http.Request, ar *Author ret.Code = code // save authorization token - if err = w.Storage.SaveAuthorize(ret); err != nil { + if err = w.Storage.SaveAuthorize(ctx, ret); err != nil { w.SetErrorState(E_SERVER_ERROR, "", ar.State) w.InternalError = err return diff --git a/authorize_test.go b/authorize_test.go index c4165f8..9bbfdba 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -227,7 +227,7 @@ func TestAuthorizeCodePKCEPlain(t *testing.T) { t.Fatalf("Unexpected authorization code: %s", code) } - token, err := server.Storage.LoadAuthorize(code) + token, err := server.Storage.LoadAuthorize(req.Context(), code) if err != nil { t.Fatalf("Unexpected error: %s", err) } @@ -283,7 +283,7 @@ func TestAuthorizeCodePKCES256(t *testing.T) { t.Fatalf("Unexpected authorization code: %s", code) } - token, err := server.Storage.LoadAuthorize(code) + token, err := server.Storage.LoadAuthorize(req.Context(), code) if err != nil { t.Fatalf("Unexpected error: %s", err) } diff --git a/info.go b/info.go index b3c73ca..e8a327f 100644 --- a/info.go +++ b/info.go @@ -15,6 +15,7 @@ type InfoRequest struct { // NOT an RFC specification. func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest { r.ParseForm() + ctx := r.Context() bearer := CheckBearerAuth(r) if bearer == nil { s.setErrorAndLog(w, E_INVALID_REQUEST, nil, "handle_info_request=%s", "bearer is nil") @@ -34,7 +35,7 @@ func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest { var err error // load access data - ret.AccessData, err = w.Storage.LoadAccess(ret.Code) + ret.AccessData, err = w.Storage.LoadAccess(ctx, ret.Code) if err != nil { s.setErrorAndLog(w, E_INVALID_REQUEST, err, "handle_info_request=%s", "failed to load access data") return nil diff --git a/response.go b/response.go index c561211..f442454 100644 --- a/response.go +++ b/response.go @@ -1,6 +1,7 @@ package osin import ( + "context" "errors" "fmt" "net/http" @@ -44,7 +45,7 @@ func NewResponse(storage Storage) *Response { Output: make(ResponseData), Headers: make(http.Header), IsError: false, - Storage: storage.Clone(), + Storage: storage.Clone(context.TODO()), } r.Headers.Add( "Cache-Control", diff --git a/storage.go b/storage.go index 425fc36..9acce82 100644 --- a/storage.go +++ b/storage.go @@ -1,6 +1,7 @@ package osin import ( + "context" "errors" ) @@ -18,42 +19,42 @@ type Storage interface { // to avoid concurrent access problems. // This is to avoid cloning the connection at each method access. // Can return itself if not a problem. - Clone() Storage + Clone(ctx context.Context) Storage // Close the resources the Storage potentially holds (using Clone for example) Close() // GetClient loads the client by id (client_id) - GetClient(id string) (Client, error) + GetClient(ctx context.Context, id string) (Client, error) // SaveAuthorize saves authorize data. - SaveAuthorize(*AuthorizeData) error + SaveAuthorize(ctx context.Context, authData *AuthorizeData) error // LoadAuthorize looks up AuthorizeData by a code. // Client information MUST be loaded together. // Optionally can return error if expired. - LoadAuthorize(code string) (*AuthorizeData, error) + LoadAuthorize(ctx context.Context, code string) (*AuthorizeData, error) // RemoveAuthorize revokes or deletes the authorization code. - RemoveAuthorize(code string) error + RemoveAuthorize(ctx context.Context, code string) error // SaveAccess writes AccessData. // If RefreshToken is not blank, it must save in a way that can be loaded using LoadRefresh. - SaveAccess(*AccessData) error + SaveAccess(ctx context.Context, accessData *AccessData) error // LoadAccess retrieves access data by token. Client information MUST be loaded together. // AuthorizeData and AccessData DON'T NEED to be loaded if not easily available. // Optionally can return error if expired. - LoadAccess(token string) (*AccessData, error) + LoadAccess(ctx context.Context, token string) (*AccessData, error) // RemoveAccess revokes or deletes an AccessData. - RemoveAccess(token string) error + RemoveAccess(ctx context.Context, token string) error // LoadRefresh retrieves refresh AccessData. Client information MUST be loaded together. // AuthorizeData and AccessData DON'T NEED to be loaded if not easily available. // Optionally can return error if expired. - LoadRefresh(token string) (*AccessData, error) + LoadRefresh(ctx context.Context, token string) (*AccessData, error) // RemoveRefresh revokes or deletes refresh AccessData. - RemoveRefresh(token string) error + RemoveRefresh(ctx context.Context, token string) error } diff --git a/storage_test.go b/storage_test.go index 639d23a..2c5d518 100644 --- a/storage_test.go +++ b/storage_test.go @@ -1,6 +1,7 @@ package osin import ( + "context" "strconv" "time" ) @@ -62,14 +63,14 @@ func NewTestingStorage() *TestingStorage { return r } -func (s *TestingStorage) Clone() Storage { +func (s *TestingStorage) Clone(ctx context.Context) Storage { return s } func (s *TestingStorage) Close() { } -func (s *TestingStorage) GetClient(id string) (Client, error) { +func (s *TestingStorage) GetClient(ctx context.Context, id string) (Client, error) { if c, ok := s.clients[id]; ok { return c, nil } @@ -81,24 +82,24 @@ func (s *TestingStorage) SetClient(id string, client Client) error { return nil } -func (s *TestingStorage) SaveAuthorize(data *AuthorizeData) error { +func (s *TestingStorage) SaveAuthorize(ctx context.Context, data *AuthorizeData) error { s.authorize[data.Code] = data return nil } -func (s *TestingStorage) LoadAuthorize(code string) (*AuthorizeData, error) { +func (s *TestingStorage) LoadAuthorize(ctx context.Context, code string) (*AuthorizeData, error) { if d, ok := s.authorize[code]; ok { return d, nil } return nil, ErrNotFound } -func (s *TestingStorage) RemoveAuthorize(code string) error { +func (s *TestingStorage) RemoveAuthorize(ctx context.Context, code string) error { delete(s.authorize, code) return nil } -func (s *TestingStorage) SaveAccess(data *AccessData) error { +func (s *TestingStorage) SaveAccess(ctx context.Context, data *AccessData) error { s.access[data.AccessToken] = data if data.RefreshToken != "" { s.refresh[data.RefreshToken] = data.AccessToken @@ -106,26 +107,26 @@ func (s *TestingStorage) SaveAccess(data *AccessData) error { return nil } -func (s *TestingStorage) LoadAccess(code string) (*AccessData, error) { +func (s *TestingStorage) LoadAccess(ctx context.Context, code string) (*AccessData, error) { if d, ok := s.access[code]; ok { return d, nil } return nil, ErrNotFound } -func (s *TestingStorage) RemoveAccess(code string) error { +func (s *TestingStorage) RemoveAccess(ctx context.Context, code string) error { delete(s.access, code) return nil } -func (s *TestingStorage) LoadRefresh(code string) (*AccessData, error) { +func (s *TestingStorage) LoadRefresh(ctx context.Context, code string) (*AccessData, error) { if d, ok := s.refresh[code]; ok { - return s.LoadAccess(d) + return s.LoadAccess(ctx, d) } return nil, ErrNotFound } -func (s *TestingStorage) RemoveRefresh(code string) error { +func (s *TestingStorage) RemoveRefresh(ctx context.Context, code string) error { delete(s.refresh, code) return nil }