From 340d6e408b05ed0b9ae2d26e97db9876df8d6dcd Mon Sep 17 00:00:00 2001 From: Hafsa Imran Date: Thu, 28 Nov 2024 13:05:13 -0500 Subject: [PATCH] small cleanup + small fix to test --- saml/response.go | 11 +++++------ saml/response_test.go | 2 +- saml/test/provider.go | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/saml/response.go b/saml/response.go index f1fc384..d57522f 100644 --- a/saml/response.go +++ b/saml/response.go @@ -110,7 +110,6 @@ func (sp *ServiceProvider) ParseResponse( // This will validate the response and all assertions. response, err := ip.ValidateEncodedResponse(samlResp) - switch { case err != nil: return nil, fmt.Errorf("%s: unable to validate encoded response: %w", op, err) @@ -154,7 +153,7 @@ func (sp *ServiceProvider) ParseResponse( } if !opts.skipSignatureValidation { - // func ip.ValidateEncodedResponse(...) above only requires either response or all its assertions are signed, + // func ip.ValidateEncodedResponse(...) above only requires either `response or all its `assertions` are signed, // but does not require both. Adding another check to validate that both of these are signed always. if err := validateSignature(response, op); err != nil { return nil, err @@ -257,17 +256,17 @@ func parsePEMCertificate(cert []byte) (*x509.Certificate, error) { } func validateSignature(response *types.Response, op string) error { - // validate child attr assertions + // validate child object assertions for _, assert := range response.Assertions { if !assert.SignatureValidated { // note: at one time func ip.ValidateEncodedResponse(...) above allows all signed or all unsigned - // assertions, and will give error if there are both. We are still looping on all assertions instead of - // retrieving value for one assertion, so we do not depend on dependency implementation. + // assertions, and will give error if there is a mix of both. We are still looping on all assertions + // instead of retrieving value for one assertion, so we do not depend on dependency implementation. return fmt.Errorf("%s: %w", op, ErrInvalidSignature) } } - // validate root response attr + // validate root object response if !response.SignatureValidated { return fmt.Errorf("%s: %w", op, ErrInvalidSignature) } diff --git a/saml/response_test.go b/saml/response_test.go index e2cf3d9..63842bd 100644 --- a/saml/response_test.go +++ b/saml/response_test.go @@ -84,7 +84,7 @@ func TestServiceProvider_ParseResponse(t *testing.T) { { name: "error-invalid-signature", sp: testSp, - samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustResponseElemSigned()))), + samlResp: base64.StdEncoding.EncodeToString([]byte(tp.SamlResponse(t, testprovider.WithJustAssertionElemSigned()))), opts: []saml.Option{}, requestID: testRequestId, wantErrContains: "invalid signature", diff --git a/saml/test/provider.go b/saml/test/provider.go index 919785f..c411f68 100644 --- a/saml/test/provider.go +++ b/saml/test/provider.go @@ -561,20 +561,20 @@ func (p *TestProvider) SamlResponse(t *testing.T, opts ...ResponseOption) string if opt.signResponseElem || opt.signAssertionElem { signCtx := dsig.NewDefaultSigningContext(p.keystore) - // sign child attr assertions + // sign child object assertions if opt.signAssertionElem { responseEl := doc.SelectElement("Response") for _, assert := range responseEl.FindElements("Assertion") { signedAssert, err := signCtx.SignEnveloped(assert) r.NoError(err) - // replace signed assert element + // replace signed assert object responseEl.RemoveChildAt(assert.Index()) responseEl.AddChild(signedAssert) } } - // sign root attr response + // sign root object response if opt.signResponseElem { signed, err := signCtx.SignEnveloped(doc.Root()) r.NoError(err)