Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added context support #176

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions access.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package osin

import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -151,6 +152,7 @@ func (s *Server) HandleAccessRequest(w *Response, r *http.Request) *AccessReques

func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *AccessRequest {
// get client authentication
ctx := r.Context()
auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams)
if auth == nil {
return nil
Expand All @@ -174,13 +176,13 @@ func (s *Server) handleAuthorizationCodeRequest(w *Response, r *http.Request) *A
}

// must have a valid client
if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil {
if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil {
return nil
}

// must be a valid authorization code
var err error
ret.AuthorizeData, err = w.Storage.LoadAuthorize(ret.Code)
ret.AuthorizeData, err = w.Storage.LoadAuthorize(ctx, ret.Code)
if err != nil {
s.setErrorAndLog(w, E_INVALID_GRANT, err, "auth_code_request=%s", "error loading authorize data")
return nil
Expand Down Expand Up @@ -283,6 +285,7 @@ func extraScopes(access_scopes, refresh_scopes string) bool {

func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *AccessRequest {
// get client authentication
ctx := r.Context()
auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams)
if auth == nil {
return nil
Expand All @@ -305,13 +308,13 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access
}

// must have a valid client
if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil {
if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil {
return nil
}

// must be a valid refresh code
var err error
ret.AccessData, err = w.Storage.LoadRefresh(ret.Code)
ret.AccessData, err = w.Storage.LoadRefresh(ctx, ret.Code)
if err != nil {
s.setErrorAndLog(w, E_INVALID_GRANT, err, "refresh_token=%s", "error loading access data")
return nil
Expand Down Expand Up @@ -354,6 +357,7 @@ func (s *Server) handleRefreshTokenRequest(w *Response, r *http.Request) *Access

func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequest {
// get client authentication
ctx := r.Context()
auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams)
if auth == nil {
return nil
Expand All @@ -377,7 +381,7 @@ func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequ
}

// must have a valid client
if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil {
if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil {
return nil
}

Expand All @@ -389,6 +393,7 @@ func (s *Server) handlePasswordRequest(w *Response, r *http.Request) *AccessRequ

func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *AccessRequest {
// get client authentication
ctx := r.Context()
auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams)
if auth == nil {
return nil
Expand All @@ -404,7 +409,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A
}

// must have a valid client
if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil {
if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil {
return nil
}

Expand All @@ -416,6 +421,7 @@ func (s *Server) handleClientCredentialsRequest(w *Response, r *http.Request) *A

func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessRequest {
// get client authentication
ctx := r.Context()
auth := s.getClientAuth(w, r, s.Config.AllowClientSecretInParams)
if auth == nil {
return nil
Expand All @@ -439,7 +445,7 @@ func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessReq
}

// must have a valid client
if ret.Client = s.getClient(auth, w.Storage, w); ret.Client == nil {
if ret.Client = s.getClient(ctx, auth, w.Storage, w); ret.Client == nil {
return nil
}

Expand All @@ -451,6 +457,7 @@ func (s *Server) handleAssertionRequest(w *Response, r *http.Request) *AccessReq

func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessRequest) {
// don't process if is already an error
ctx := r.Context()
if w.IsError {
return
}
Expand Down Expand Up @@ -487,22 +494,22 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq
}

// save access token
if err = w.Storage.SaveAccess(ret); err != nil {
if err = w.Storage.SaveAccess(ctx, ret); err != nil {
s.setErrorAndLog(w, E_SERVER_ERROR, err, "finish_access_request=%s", "error saving access token")
return
}

// remove authorization token
if ret.AuthorizeData != nil {
w.Storage.RemoveAuthorize(ret.AuthorizeData.Code)
w.Storage.RemoveAuthorize(ctx, ret.AuthorizeData.Code)
}

// remove previous access token
if ret.AccessData != nil && !s.Config.RetainTokenAfterRefresh {
if ret.AccessData.RefreshToken != "" {
w.Storage.RemoveRefresh(ret.AccessData.RefreshToken)
w.Storage.RemoveRefresh(ctx, ret.AccessData.RefreshToken)
}
w.Storage.RemoveAccess(ret.AccessData.AccessToken)
w.Storage.RemoveAccess(ctx, ret.AccessData.AccessToken)
}

// output data
Expand All @@ -524,8 +531,9 @@ func (s *Server) FinishAccessRequest(w *Response, r *http.Request, ar *AccessReq

// getClient looks up and authenticates the basic auth using the given
// storage. Sets an error on the response if auth fails or a server error occurs.
func (s Server) getClient(auth *BasicAuth, storage Storage, w *Response) Client {
client, err := storage.GetClient(auth.Username)
func (s Server) getClient(ctx context.Context, auth *BasicAuth, storage Storage, w *Response) Client {

client, err := storage.GetClient(ctx, auth.Username)
if err == ErrNotFound {
s.setErrorAndLog(w, E_UNAUTHORIZED_CLIENT, nil, "get_client=%s", "not found")
return nil
Expand Down Expand Up @@ -559,4 +567,4 @@ func (s Server) setErrorAndLog(w *Response, responseError string, internalError
w.SetError(responseError, "")

s.Logger.Printf(format, append([]interface{}{responseError, internalError}, debugArgs...)...)
}
}
22 changes: 12 additions & 10 deletions access_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package osin

import (
"context"
"net/http"
"net/url"
"testing"
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestAccessRefreshToken(t *testing.T) {
}
//fmt.Printf("%+v", resp)

if _, err := server.Storage.LoadRefresh("r9999"); err == nil {
if _, err := server.Storage.LoadRefresh(req.Context(), "r9999"); err == nil {
t.Fatalf("token was not deleted")
}

Expand Down Expand Up @@ -130,7 +131,7 @@ func TestAccessRefreshTokenSaveToken(t *testing.T) {
}
//fmt.Printf("%+v", resp)

if _, err := server.Storage.LoadRefresh("r9999"); err != nil {
if _, err := server.Storage.LoadRefresh(req.Context(), "r9999"); err != nil {
t.Fatalf("token incorrectly deleted: %s", err.Error())
}

Expand Down Expand Up @@ -293,15 +294,15 @@ func TestGetClientWithoutMatcher(t *testing.T) {
storage := &TestingStorage{clients: map[string]Client{myclient.Id: myclient}}
sconfig := NewServerConfig()
server := NewServer(sconfig, storage)

ctx := context.Background()
// Ensure bad secret fails
{
auth := &BasicAuth{
Username: "myclient",
Password: "invalidsecret",
}
w := &Response{}
client := server.getClient(auth, storage, w)
client := server.getClient(ctx, auth, storage, w)
if client != nil {
t.Errorf("Expected error, got client: %v", client)
}
Expand All @@ -322,7 +323,7 @@ func TestGetClientWithoutMatcher(t *testing.T) {
Password: "nonexistent",
}
w := &Response{}
client := server.getClient(auth, storage, w)
client := server.getClient(ctx, auth, storage, w)
if client != nil {
t.Errorf("Expected error, got client: %v", client)
}
Expand All @@ -343,7 +344,7 @@ func TestGetClientWithoutMatcher(t *testing.T) {
Password: "myclientsecret",
}
w := &Response{}
client := server.getClient(auth, storage, w)
client := server.getClient(ctx, auth, storage, w)
if client != myclient {
t.Errorf("Expected client, got nil with response: %v", w)
}
Expand Down Expand Up @@ -374,15 +375,15 @@ func TestGetClientSecretMatcher(t *testing.T) {
storage := &TestingStorage{clients: map[string]Client{myclient.Id: myclient}}
sconfig := NewServerConfig()
server := NewServer(sconfig, storage)

ctx := context.Background()
// Ensure bad secret fails, but does not panic (doesn't call GetSecret)
{
auth := &BasicAuth{
Username: "myclient",
Password: "invalidsecret",
}
w := &Response{}
client := server.getClient(auth, storage, w)
client := server.getClient(ctx, auth, storage, w)
if client != nil {
t.Errorf("Expected error, got client: %v", client)
}
Expand All @@ -395,7 +396,7 @@ func TestGetClientSecretMatcher(t *testing.T) {
Password: "myclientsecret",
}
w := &Response{}
client := server.getClient(auth, storage, w)
client := server.getClient(ctx, auth, storage, w)
if client != myclient {
t.Errorf("Expected client, got nil with response: %v", w)
}
Expand Down Expand Up @@ -437,12 +438,13 @@ func TestAccessAuthorizationCodePKCE(t *testing.T) {
}

for k, test := range testcases {
ctx := context.Background()
testStorage := NewTestingStorage()
sconfig := NewServerConfig()
sconfig.AllowedAccessTypes = AllowedAccessType{AUTHORIZATION_CODE}
server := NewServer(sconfig, testStorage)
server.AccessTokenGen = &TestingAccessTokenGen{}
server.Storage.SaveAuthorize(&AuthorizeData{
server.Storage.SaveAuthorize(ctx, &AuthorizeData{
Client: testStorage.clients["public-client"],
Code: "pkce-code",
ExpiresIn: 3600,
Expand Down
7 changes: 4 additions & 3 deletions authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ type AuthorizeTokenGen interface {
// authorization requests
func (s *Server) HandleAuthorizeRequest(w *Response, r *http.Request) *AuthorizeRequest {
r.ParseForm()

ctx := r.Context()
// create the authorization request
unescapedUri, err := url.QueryUnescape(r.Form.Get("redirect_uri"))
if err != nil {
Expand All @@ -123,7 +123,7 @@ func (s *Server) HandleAuthorizeRequest(w *Response, r *http.Request) *Authorize
}

// must have a valid client
ret.Client, err = w.Storage.GetClient(r.Form.Get("client_id"))
ret.Client, err = w.Storage.GetClient(ctx, r.Form.Get("client_id"))
if err == ErrNotFound {
w.SetErrorState(E_UNAUTHORIZED_CLIENT, "", ret.State)
return nil
Expand Down Expand Up @@ -205,6 +205,7 @@ func (s *Server) HandleAuthorizeRequest(w *Response, r *http.Request) *Authorize

func (s *Server) FinishAuthorizeRequest(w *Response, r *http.Request, ar *AuthorizeRequest) {
// don't process if is already an error
ctx := r.Context()
if w.IsError {
return
}
Expand Down Expand Up @@ -258,7 +259,7 @@ func (s *Server) FinishAuthorizeRequest(w *Response, r *http.Request, ar *Author
ret.Code = code

// save authorization token
if err = w.Storage.SaveAuthorize(ret); err != nil {
if err = w.Storage.SaveAuthorize(ctx, ret); err != nil {
w.SetErrorState(E_SERVER_ERROR, "", ar.State)
w.InternalError = err
return
Expand Down
4 changes: 2 additions & 2 deletions authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func TestAuthorizeCodePKCEPlain(t *testing.T) {
t.Fatalf("Unexpected authorization code: %s", code)
}

token, err := server.Storage.LoadAuthorize(code)
token, err := server.Storage.LoadAuthorize(req.Context(), code)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
Expand Down Expand Up @@ -283,7 +283,7 @@ func TestAuthorizeCodePKCES256(t *testing.T) {
t.Fatalf("Unexpected authorization code: %s", code)
}

token, err := server.Storage.LoadAuthorize(code)
token, err := server.Storage.LoadAuthorize(req.Context(), code)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
Expand Down
3 changes: 2 additions & 1 deletion info.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type InfoRequest struct {
// NOT an RFC specification.
func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest {
r.ParseForm()
ctx := r.Context()
bearer := CheckBearerAuth(r)
if bearer == nil {
s.setErrorAndLog(w, E_INVALID_REQUEST, nil, "handle_info_request=%s", "bearer is nil")
Expand All @@ -34,7 +35,7 @@ func (s *Server) HandleInfoRequest(w *Response, r *http.Request) *InfoRequest {
var err error

// load access data
ret.AccessData, err = w.Storage.LoadAccess(ret.Code)
ret.AccessData, err = w.Storage.LoadAccess(ctx, ret.Code)
if err != nil {
s.setErrorAndLog(w, E_INVALID_REQUEST, err, "handle_info_request=%s", "failed to load access data")
return nil
Expand Down
3 changes: 2 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package osin

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -44,7 +45,7 @@ func NewResponse(storage Storage) *Response {
Output: make(ResponseData),
Headers: make(http.Header),
IsError: false,
Storage: storage.Clone(),
Storage: storage.Clone(context.TODO()),
}
r.Headers.Add(
"Cache-Control",
Expand Down
Loading