Skip to content

Commit

Permalink
making validation of signature of both fields as optional and adding …
Browse files Browse the repository at this point in the history
…unit tests to cover
  • Loading branch information
himran92 committed Dec 3, 2024
1 parent 3748d22 commit 5f2819d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 21 deletions.
41 changes: 29 additions & 12 deletions saml/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ import (
)

type parseResponseOptions struct {
clock clockwork.Clock
skipRequestIDValidation bool
skipAssertionConditionValidation bool
skipSignatureValidation bool
assertionConsumerServiceURL string
clock clockwork.Clock
skipRequestIDValidation bool
skipAssertionConditionValidation bool
skipSignatureValidation bool
assertionConsumerServiceURL string
requireSignatureForResponseAndAssertion bool
}

func parseResponseOptionsDefault() parseResponseOptions {
return parseResponseOptions{
clock: clockwork.NewRealClock(),
skipRequestIDValidation: false,
skipAssertionConditionValidation: false,
skipSignatureValidation: false,
clock: clockwork.NewRealClock(),
skipRequestIDValidation: false,
skipAssertionConditionValidation: false,
skipSignatureValidation: false,
requireSignatureForResponseAndAssertion: false,
}
}

Expand Down Expand Up @@ -73,6 +75,15 @@ func InsecureSkipSignatureValidation() Option {
}
}

// RequireSignatureForBothResponseAndAssertion enables validation of both the SAML Response and its assertions.
func RequireSignatureForBothResponseAndAssertion() Option {
return func(o interface{}) {
if o, ok := o.(*parseResponseOptions); ok {
o.requireSignatureForResponseAndAssertion = true
}
}
}

// ParseResponse parses and validates a SAML Reponse.
//
// Options:
Expand All @@ -87,15 +98,19 @@ func (sp *ServiceProvider) ParseResponse(
opt ...Option,
) (*core.Response, error) {
const op = "saml.(ServiceProvider).ParseResponse"
opts := getParseResponseOptions(opt...)

switch {
case sp == nil:
return nil, fmt.Errorf("%s: missing service provider %w", op, ErrInternal)
case samlResp == "":
return nil, fmt.Errorf("%s: missing saml response: %w", op, ErrInvalidParameter)
case requestID == "":
return nil, fmt.Errorf("%s: missing request ID: %w", op, ErrInvalidParameter)
case opts.skipSignatureValidation && opts.requireSignatureForResponseAndAssertion:
return nil, fmt.Errorf("%s: option `skip signature validation` cannot be true with `require signature"+
" for response and assertion` : %w", op, ErrInvalidParameter)
}
opts := getParseResponseOptions(opt...)

// We use github.com/russellhaering/gosaml2 for SAMLResponse signature and condition validation.
ip, err := sp.internalParser(
Expand Down Expand Up @@ -152,9 +167,11 @@ func (sp *ServiceProvider) ParseResponse(
}

samlResponse := core.Response{Response: *response}
if !opts.skipSignatureValidation {
if opts.requireSignatureForResponseAndAssertion {
// func ip.ValidateEncodedResponse(...) above only requires either `response or all its `assertions` are signed,
// but does not require both. Adding another check to validate that both of these are signed always.
// but does not require both.
// If option requireSignatureForResponseAndAssertion is true, adding another check to validate that both of
// these are signed always.
if err := validateSignature(&samlResponse, op); err != nil {
return nil, err
}
Expand Down
37 changes: 29 additions & 8 deletions saml/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,33 @@ func TestServiceProvider_ParseResponse(t *testing.T) {
wantErrAs error
}{
{
name: "success",
name: "success - with both response and assertion signed",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithCompleteResponseSigned()))),
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithResponseAndAssertionSigned()))),
opts: []saml.Option{},
requestID: testRequestId,
},
{
name: "success - with just response signed",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustResponseElemSigned()))),
opts: []saml.Option{},
requestID: testRequestId,
},
{
name: "success - with just assertion signed",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustAssertionElemSigned()))),
opts: []saml.Option{},
requestID: testRequestId,
},
{
name: "success - with both response and assertion signed and both signature required",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithResponseAndAssertionSigned()))),
opts: []saml.Option{saml.RequireSignatureForBothResponseAndAssertion()},
requestID: testRequestId,
},
{
name: "missing signature",
sp: testSp,
Expand All @@ -74,18 +95,18 @@ func TestServiceProvider_ParseResponse(t *testing.T) {
wantErrContains: "response and/or assertions must be signed",
},
{
name: "error-invalid-signature",
name: "error-invalid-signature - with just response signed",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustResponseElemSigned()))),
opts: []saml.Option{},
opts: []saml.Option{saml.RequireSignatureForBothResponseAndAssertion()},
requestID: testRequestId,
wantErrContains: "invalid signature",
},
{
name: "error-invalid-signature",
name: "error-invalid-signature - with just assertion signed",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustAssertionElemSigned()))),
opts: []saml.Option{},
opts: []saml.Option{saml.RequireSignatureForBothResponseAndAssertion()},
requestID: testRequestId,
wantErrContains: "invalid signature",
},
Expand Down Expand Up @@ -167,15 +188,15 @@ func TestServiceProvider_ParseResponse(t *testing.T) {
{
name: "err-in-response-to",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithCompleteResponseSigned()))),
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithResponseAndAssertionSigned()))),
requestID: "invalid-request-id",
wantErrContains: "doesn't match the expected requestID (invalid-request-id)",
},
{
name: "expired",
sp: testSp,
samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t,
testprovider.WithCompleteResponseSigned(),
testprovider.WithResponseAndAssertionSigned(),
testprovider.WithResponseExpired(),
))),
requestID: "request-id",
Expand Down
2 changes: 1 addition & 1 deletion saml/test/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func defaultResponseOptions() *responseOptions {
return &responseOptions{}
}

func WithCompleteResponseSigned() ResponseOption {
func WithResponseAndAssertionSigned() ResponseOption {
return func(o *responseOptions) {
o.signResponseElem = true
o.signAssertionElem = true
Expand Down

0 comments on commit 5f2819d

Please sign in to comment.