Skip to content

Commit

Permalink
fix !ok authentication bug
Browse files Browse the repository at this point in the history
  • Loading branch information
SamMHD committed Feb 24, 2024
1 parent 9eeaab7 commit 3870b41
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
33 changes: 18 additions & 15 deletions pkg/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,17 @@ func (a *Authenticator) readToken(request *Request, wsvc WebservicesCacheEntry)

// readService reads requested webservice from cache and
// will return error if the object would not be found in cache
func (a *Authenticator) readService(wsvc string) (bool, CerberusReason, WebservicesCacheEntry) {
func (a *Authenticator) readService(wsvc string) (CerberusReason, WebservicesCacheEntry) {
a.cacheLock.RLock()
cacheReaders.Inc()
defer a.cacheLock.RUnlock()
defer cacheReaders.Dec()

res, ok := a.webservicesCache.ReadWebservice(wsvc)
if !ok {
return false, CerberusReasonWebserviceNotFound, WebservicesCacheEntry{}
return CerberusReasonWebserviceNotFound, WebservicesCacheEntry{}
}
return true, "", res
return "", res
}

func toExtraHeaders(headers CerberusExtraHeaders) ExtraHeaders {
Expand All @@ -139,7 +139,7 @@ func toExtraHeaders(headers CerberusExtraHeaders) ExtraHeaders {
func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response, error) {
wsvc, ns, reason := readRequestContext(request)
if reason != "" {
return generateResponse(false, reason, nil), nil
return generateResponse(reason, nil), nil
}
wsvc = v1alpha1.WebserviceReference{
Name: wsvc,
Expand All @@ -149,14 +149,14 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response,
request.Context[HasUpstreamAuth] = "false"
var extraHeaders ExtraHeaders

ok, reason, wsvcCacheEntry := a.readService(wsvc)
if ok {
reason, wsvcCacheEntry := a.readService(wsvc)
if reason == "" {
var cerberusExtraHeaders CerberusExtraHeaders
reason, cerberusExtraHeaders = a.TestAccess(request, wsvcCacheEntry)
extraHeaders = toExtraHeaders(cerberusExtraHeaders)
if reason == CerberusReasonOK && hasUpstreamAuth(wsvcCacheEntry) {
request.Context[HasUpstreamAuth] = "true"
ok, reason = a.checkServiceUpstreamAuth(wsvcCacheEntry, request, &extraHeaders, ctx)
reason = a.checkServiceUpstreamAuth(wsvcCacheEntry, request, &extraHeaders, ctx)
}
}

Expand All @@ -165,7 +165,7 @@ func (a *Authenticator) Check(ctx context.Context, request *Request) (*Response,
err = status.Error(codes.DeadlineExceeded, "Timeout exceeded")
}

return generateResponse(ok, reason, extraHeaders), err
return generateResponse(reason, extraHeaders), err
}

func readRequestContext(request *Request) (wsvc string, ns string, reason CerberusReason) {
Expand Down Expand Up @@ -271,17 +271,17 @@ func processResponseError(err error) CerberusReason {

// checkServiceUpstreamAuth function is designed to validate the request through
// the upstream authentication for a given webservice
func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, request *Request, extraHeaders *ExtraHeaders, ctx context.Context) (bool, CerberusReason) {
func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry, request *Request, extraHeaders *ExtraHeaders, ctx context.Context) CerberusReason {
downstreamDeadline, hasDownstreamDeadline := ctx.Deadline()
serviceUpstreamAuthCalls.With(AddWithDownstreamDeadline(nil, hasDownstreamDeadline)).Inc()

if reason := validateUpstreamAuthRequest(service); reason != "" {
return false, reason
return reason
}
upstreamAuth := service.Spec.UpstreamHttpAuth
req, err := setupUpstreamAuthRequest(&upstreamAuth, request)
if err != nil {
return false, CerberusReasonUpstreamAuthNoReq
return CerberusReasonUpstreamAuthNoReq
}
a.adjustTimeout(upstreamAuth.Timeout, downstreamDeadline, hasDownstreamDeadline)

Expand All @@ -290,18 +290,18 @@ func (a *Authenticator) checkServiceUpstreamAuth(service WebservicesCacheEntry,
reqDuration := time.Since(reqStart)

if reason := processResponseError(err); reason != "" {
return false, reason
return reason
}

labels := AddWithDownstreamDeadline(AddStatusLabel(nil, resp.StatusCode), hasDownstreamDeadline)
upstreamAuthRequestDuration.With(labels).Observe(reqDuration.Seconds())

if resp.StatusCode != http.StatusOK {
return false, CerberusReasonUnauthorized
return CerberusReasonUnauthorized
}
// add requested careHeaders to extraHeaders for response
copyUpstreamHeaders(resp, extraHeaders, service.Spec.UpstreamHttpAuth.CareHeaders)
return true, CerberusReasonOK
return ""
}

// hasUpstreamAuth evaluates whether the provided webservice
Expand All @@ -313,10 +313,13 @@ func hasUpstreamAuth(service WebservicesCacheEntry) bool {
// generateResponse initializes defaults for cerberus http result and creates a
// valid response from cerberus reasons and computed headers to inform the client
// that it has the access or not.
func generateResponse(ok bool, reason CerberusReason, extraHeaders ExtraHeaders) *Response {
func generateResponse(reason CerberusReason, extraHeaders ExtraHeaders) *Response {
ok := (reason == "")

var httpStatusCode int
if ok {
httpStatusCode = http.StatusOK
reason = CerberusReasonOK
} else {
httpStatusCode = http.StatusUnauthorized
}
Expand Down
11 changes: 4 additions & 7 deletions pkg/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,7 @@ func TestReadService(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.wsvc, func(t *testing.T) {
ok, reason, _ := authenticator.readService(tc.wsvc)
if ok != tc.expectedOk {
t.Errorf("Expected success: %v, Got: %v", tc.expectedOk, ok)
}
reason, _ := authenticator.readService(tc.wsvc)
if reason != tc.expectedReason {
t.Errorf("Expected reason: %v, Got: %v", tc.expectedReason, reason)
}
Expand Down Expand Up @@ -729,11 +726,11 @@ func Test_generateResponse(t *testing.T) {
StatusCode: http.StatusOK,
Header: http.Header{
ExternalAuthHandlerHeader: {"cerberus"},
CerberusHeaderReasonHeader: {"reason"},
CerberusHeaderReasonHeader: {string(CerberusReasonOK)},
},
},
}
actualResponse := generateResponse(true, "reason", nil)
actualResponse := generateResponse("", nil)
assert.Equal(t, expectedResponse.Allow, actualResponse.Allow, "Response should be allowed")
assert.Equal(t, expectedResponse.Response.StatusCode, actualResponse.Response.StatusCode, "HTTP status code should match")
assert.Equal(t, expectedResponse.Response.Header, actualResponse.Response.Header, "Response headers should match")
Expand All @@ -751,7 +748,7 @@ func Test_generateResponse(t *testing.T) {
},
},
}
actualResponse = generateResponse(false, "reason", extraHeaders)
actualResponse = generateResponse("reason", extraHeaders)
assert.Equal(t, expectedResponse.Allow, actualResponse.Allow, "Response should not be allowed")
assert.Equal(t, expectedResponse.Response.StatusCode, actualResponse.Response.StatusCode, "HTTP status code should match")
assert.Equal(t, expectedResponse.Response.Header, actualResponse.Response.Header, "Response headers should match")
Expand Down

0 comments on commit 3870b41

Please sign in to comment.