diff --git a/saml/response.go b/saml/response.go index cdd4264..4592db7 100644 --- a/saml/response.go +++ b/saml/response.go @@ -19,19 +19,21 @@ import ( ) type parseResponseOptions struct { - clock clockwork.Clock - skipRequestIDValidation bool - skipAssertionConditionValidation bool - skipSignatureValidation bool - assertionConsumerServiceURL string + clock clockwork.Clock + skipRequestIDValidation bool + skipAssertionConditionValidation bool + skipSignatureValidation bool + assertionConsumerServiceURL string + requireSignatureForResponseAndAssertion bool } func parseResponseOptionsDefault() parseResponseOptions { return parseResponseOptions{ - clock: clockwork.NewRealClock(), - skipRequestIDValidation: false, - skipAssertionConditionValidation: false, - skipSignatureValidation: false, + clock: clockwork.NewRealClock(), + skipRequestIDValidation: false, + skipAssertionConditionValidation: false, + skipSignatureValidation: false, + requireSignatureForResponseAndAssertion: false, } } @@ -73,6 +75,15 @@ func InsecureSkipSignatureValidation() Option { } } +// RequireSignatureForBothResponseAndAssertion enables validation of both the SAML Response and its assertions. +func RequireSignatureForBothResponseAndAssertion() Option { + return func(o interface{}) { + if o, ok := o.(*parseResponseOptions); ok { + o.requireSignatureForResponseAndAssertion = true + } + } +} + // ParseResponse parses and validates a SAML Reponse. // // Options: @@ -87,6 +98,8 @@ func (sp *ServiceProvider) ParseResponse( opt ...Option, ) (*core.Response, error) { const op = "saml.(ServiceProvider).ParseResponse" + opts := getParseResponseOptions(opt...) + switch { case sp == nil: return nil, fmt.Errorf("%s: missing service provider %w", op, ErrInternal) @@ -94,8 +107,10 @@ func (sp *ServiceProvider) ParseResponse( return nil, fmt.Errorf("%s: missing saml response: %w", op, ErrInvalidParameter) case requestID == "": return nil, fmt.Errorf("%s: missing request ID: %w", op, ErrInvalidParameter) + case opts.skipSignatureValidation && opts.requireSignatureForResponseAndAssertion: + return nil, fmt.Errorf("%s: option `skip signature validation` cannot be true with `require signature"+ + " for response and assertion` : %w", op, ErrInvalidParameter) } - opts := getParseResponseOptions(opt...) // We use github.com/russellhaering/gosaml2 for SAMLResponse signature and condition validation. ip, err := sp.internalParser( @@ -152,9 +167,11 @@ func (sp *ServiceProvider) ParseResponse( } samlResponse := core.Response{Response: *response} - if !opts.skipSignatureValidation { + if opts.requireSignatureForResponseAndAssertion { // func ip.ValidateEncodedResponse(...) above only requires either `response or all its `assertions` are signed, - // but does not require both. Adding another check to validate that both of these are signed always. + // but does not require both. + // If option requireSignatureForResponseAndAssertion is true, adding another check to validate that both of + // these are signed always. if err := validateSignature(&samlResponse, op); err != nil { return nil, err } diff --git a/saml/response_test.go b/saml/response_test.go index 63842bd..28efa65 100644 --- a/saml/response_test.go +++ b/saml/response_test.go @@ -59,12 +59,33 @@ func TestServiceProvider_ParseResponse(t *testing.T) { wantErrAs error }{ { - name: "success", + name: "success - with both response and assertion signed", sp: testSp, - samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithCompleteResponseSigned()))), + samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithResponseAndAssertionSigned()))), opts: []saml.Option{}, requestID: testRequestId, }, + { + name: "success - with just response signed", + sp: testSp, + samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustResponseElemSigned()))), + opts: []saml.Option{}, + requestID: testRequestId, + }, + { + name: "success - with just assertion signed", + sp: testSp, + samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustAssertionElemSigned()))), + opts: []saml.Option{}, + requestID: testRequestId, + }, + { + name: "success - with both response and assertion signed and both signature required", + sp: testSp, + samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithResponseAndAssertionSigned()))), + opts: []saml.Option{saml.RequireSignatureForBothResponseAndAssertion()}, + requestID: testRequestId, + }, { name: "missing signature", sp: testSp, @@ -74,18 +95,18 @@ func TestServiceProvider_ParseResponse(t *testing.T) { wantErrContains: "response and/or assertions must be signed", }, { - name: "error-invalid-signature", + name: "error-invalid-signature - with just response signed", sp: testSp, samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustResponseElemSigned()))), - opts: []saml.Option{}, + opts: []saml.Option{saml.RequireSignatureForBothResponseAndAssertion()}, requestID: testRequestId, wantErrContains: "invalid signature", }, { - name: "error-invalid-signature", + name: "error-invalid-signature - with just assertion signed", sp: testSp, samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustAssertionElemSigned()))), - opts: []saml.Option{}, + opts: []saml.Option{saml.RequireSignatureForBothResponseAndAssertion()}, requestID: testRequestId, wantErrContains: "invalid signature", }, @@ -167,7 +188,7 @@ func TestServiceProvider_ParseResponse(t *testing.T) { { name: "err-in-response-to", sp: testSp, - samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithCompleteResponseSigned()))), + samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithResponseAndAssertionSigned()))), requestID: "invalid-request-id", wantErrContains: "doesn't match the expected requestID (invalid-request-id)", }, @@ -175,7 +196,7 @@ func TestServiceProvider_ParseResponse(t *testing.T) { name: "expired", sp: testSp, samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, - testprovider.WithCompleteResponseSigned(), + testprovider.WithResponseAndAssertionSigned(), testprovider.WithResponseExpired(), ))), requestID: "request-id", diff --git a/saml/test/provider.go b/saml/test/provider.go index fa4d8f5..9b8a35d 100644 --- a/saml/test/provider.go +++ b/saml/test/provider.go @@ -451,7 +451,7 @@ func defaultResponseOptions() *responseOptions { return &responseOptions{} } -func WithCompleteResponseSigned() ResponseOption { +func WithResponseAndAssertionSigned() ResponseOption { return func(o *responseOptions) { o.signResponseElem = true o.signAssertionElem = true