From 8bfa53bdd9e64716c44f2e50866d27834d132d60 Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 9 Sep 2023 10:58:26 -0400 Subject: [PATCH 1/6] fix (saml): address possible panic if clock.Clock is nil --- saml/is_nil.go | 15 ++++++++++++ saml/is_nil_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++ saml/response.go | 4 ++++ 3 files changed, 75 insertions(+) create mode 100644 saml/is_nil.go create mode 100644 saml/is_nil_test.go diff --git a/saml/is_nil.go b/saml/is_nil.go new file mode 100644 index 0000000..c54654e --- /dev/null +++ b/saml/is_nil.go @@ -0,0 +1,15 @@ +package saml + +import "reflect" + +// isNil reports if a is nil +func isNil(a any) bool { + if a == nil { + return true + } + switch reflect.TypeOf(a).Kind() { + case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func: + return reflect.ValueOf(a).IsNil() + } + return false +} diff --git a/saml/is_nil_test.go b/saml/is_nil_test.go new file mode 100644 index 0000000..ff9fd40 --- /dev/null +++ b/saml/is_nil_test.go @@ -0,0 +1,56 @@ +package saml + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_isNil(t *testing.T) { + t.Parallel() + + var testErrNilPtr *testError + var testMapNilPtr map[string]struct{} + var testArrayNilPtr *[1]string + var testChanNilPtr *chan string + var testSliceNilPtr *[]string + var testFuncNil func() + + var testChanString chan string + + tc := []struct { + i any + want bool + }{ + {i: &testError{}, want: false}, + {i: testError{}, want: false}, + {i: &map[string]struct{}{}, want: false}, + {i: map[string]struct{}{}, want: false}, + {i: [1]string{}, want: false}, + {i: &[1]string{}, want: false}, + {i: &testChanString, want: false}, + {i: "string", want: false}, + {i: []string{}, want: false}, + {i: func() {}, want: false}, + {i: nil, want: true}, + {i: testErrNilPtr, want: true}, + {i: testMapNilPtr, want: true}, + {i: testArrayNilPtr, want: true}, + {i: testChanNilPtr, want: true}, + {i: testChanString, want: true}, + {i: testSliceNilPtr, want: true}, + {i: testFuncNil, want: true}, + } + + for i, tc := range tc { + t.Run(fmt.Sprintf("test #%d", i+1), func(t *testing.T) { + assert := assert.New(t) + assert.Equal(tc.want, isNil(tc.i)) + }) + } +} + +type testError struct{} + +func (*testError) Error() string { return "error" } diff --git a/saml/response.go b/saml/response.go index 4df36a7..8eaf5a1 100644 --- a/saml/response.go +++ b/saml/response.go @@ -156,6 +156,10 @@ func (sp *ServiceProvider) internalParser( clock clockwork.Clock, ) (*saml2.SAMLServiceProvider, error) { const op = "saml.(ServiceProvider).internalParser" + switch { + case isNil(clock): + return nil, fmt.Errorf("%s: missing clock: %w", op, ErrInvalidParameter) + } idpMetadata, err := sp.IDPMetadata() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) From 40f5fa0cd3f018778391459ffa9073488ff2cb78 Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 9 Sep 2023 10:43:35 -0400 Subject: [PATCH 2/6] fix (saml): fix possible panic in WithAdditionalACSEndpoint(...) changed location url to be passed by value to eliminate possible panic --- saml/sp.go | 3 ++- saml/sp_test.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/saml/sp.go b/saml/sp.go index 6002463..2d7c647 100644 --- a/saml/sp.go +++ b/saml/sp.go @@ -69,7 +69,8 @@ func WithACSServiceBinding(b core.ServiceBinding) Option { } } -func WithAdditionalACSEndpoint(b core.ServiceBinding, location *url.URL) Option { +// WithAdditionalACSEndpoint provides an optional additional ACS endpoint +func WithAdditionalACSEndpoint(b core.ServiceBinding, location url.URL) Option { return func(o interface{}) { if o, ok := o.(*metadataOptions); ok { o.additionalACSs = append(o.additionalACSs, metadata.Endpoint{ diff --git a/saml/sp_test.go b/saml/sp_test.go index 30ea97f..0da5530 100644 --- a/saml/sp_test.go +++ b/saml/sp_test.go @@ -253,7 +253,7 @@ func Test_CreateMetadata_Options(t *testing.T) { got := provider.CreateMetadata( saml.WithAdditionalACSEndpoint( core.ServiceBindingHTTPRedirect, - redirectEndpoint, + *redirectEndpoint, ), ) From 3b302b0f2c12d33d50e313184838b0f76514f76c Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 9 Sep 2023 10:35:49 -0400 Subject: [PATCH 3/6] refactor (saml): add WithMetadataNameIDFormat(...) Refactor WithAdditionalNameIDFormat(...) and WithNameIDFormats(...) into one new option WithMetadataNameIDFormat(...) --- saml/sp.go | 14 ++++---------- saml/sp_test.go | 9 +++------ 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/saml/sp.go b/saml/sp.go index 2d7c647..43c0aa5 100644 --- a/saml/sp.go +++ b/saml/sp.go @@ -45,18 +45,12 @@ func InsecureWantAssertionsUnsigned() Option { } } -func WithAdditionalNameIDFormat(format core.NameIDFormat) Option { +// WithMetadataNameIDFormat provides an optional name ID formats, which are +// added to the existing set. +func WithMetadataNameIDFormat(format ...core.NameIDFormat) Option { return func(o interface{}) { if o, ok := o.(*metadataOptions); ok { - o.nameIDFormats = append(o.nameIDFormats, format) - } - } -} - -func WithNameIDFormats(formats []core.NameIDFormat) Option { - return func(o interface{}) { - if o, ok := o.(*metadataOptions); ok { - o.nameIDFormats = formats + o.nameIDFormats = append(o.nameIDFormats, format...) } } } diff --git a/saml/sp_test.go b/saml/sp_test.go index 0da5530..57e7452 100644 --- a/saml/sp_test.go +++ b/saml/sp_test.go @@ -156,7 +156,7 @@ func Test_ServiceProvider_CreateMetadata(t *testing.T) { r := require.New(t) opts := []saml.Option{} if c.nameIDFormats != nil { - opts = append(opts, saml.WithNameIDFormats(c.nameIDFormats)) + opts = append(opts, saml.WithMetadataNameIDFormat(c.nameIDFormats...)) } got := provider.CreateMetadata(opts...) @@ -210,7 +210,7 @@ func Test_CreateMetadata_Options(t *testing.T) { t.Run("When option WithAdditionalNameIDFormat is set", func(t *testing.T) { r := require.New(t) got := provider.CreateMetadata( - saml.WithAdditionalNameIDFormat(core.NameIDFormatTransient), + saml.WithMetadataNameIDFormat(core.NameIDFormatTransient), ) r.Equal(got.SPSSODescriptor[0].NameIDFormat, []core.NameIDFormat{core.NameIDFormatTransient}) @@ -219,10 +219,7 @@ func Test_CreateMetadata_Options(t *testing.T) { t.Run("When option WithNameIDFormats is set", func(t *testing.T) { r := require.New(t) got := provider.CreateMetadata( - saml.WithNameIDFormats([]core.NameIDFormat{ - core.NameIDFormatEntity, - core.NameIDFormatUnspecified, - }), + saml.WithMetadataNameIDFormat(core.NameIDFormatEntity, core.NameIDFormatUnspecified), ) r.Len(got.SPSSODescriptor[0].NameIDFormat, 2) From 26c0e39a4c890499b6810ccb8c530e98225101be Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 9 Sep 2023 12:41:14 -0400 Subject: [PATCH 4/6] fix (saml): address possible panics in saml handlers --- saml/demo/main.go | 20 ++++++++++++++++---- saml/handler/acs.go | 11 +++++++++-- saml/handler/metadata.go | 12 ++++++++++-- saml/handler/post_binding.go | 9 +++++++-- saml/handler/redirect_binding.go | 11 +++++++++-- 5 files changed, 51 insertions(+), 12 deletions(-) diff --git a/saml/demo/main.go b/saml/demo/main.go index 55f301e..99f99fb 100644 --- a/saml/demo/main.go +++ b/saml/demo/main.go @@ -29,10 +29,22 @@ func main() { sp, err := saml.NewServiceProvider(cfg) exitOnError(err) - http.HandleFunc("/saml/acs", handler.ACSHandlerFunc(sp)) - http.HandleFunc("/saml/auth/redirect", handler.RedirectBindingHandlerFunc(sp)) - http.HandleFunc("/saml/auth/post", handler.PostBindingHandlerFunc(sp)) - http.HandleFunc("/metadata", handler.MetadataHandlerFunc(sp)) + acsHandler, err := handler.ACSHandlerFunc(sp) + exitOnError(err) + + redirectHandler, err := handler.RedirectBindingHandlerFunc(sp) + exitOnError(err) + + postBindHandler, err := handler.PostBindingHandlerFunc(sp) + exitOnError(err) + + metadataHandler, err := handler.MetadataHandlerFunc(sp) + exitOnError(err) + + http.HandleFunc("/saml/acs", acsHandler) + http.HandleFunc("/saml/auth/redirect", redirectHandler) + http.HandleFunc("/saml/auth/post", postBindHandler) + http.HandleFunc("/metadata", metadataHandler) http.HandleFunc("/login", func(w http.ResponseWriter, _ *http.Request) { ts, _ := template.New("sso").Parse( `
diff --git a/saml/handler/acs.go b/saml/handler/acs.go index 3f12e34..d0693e2 100644 --- a/saml/handler/acs.go +++ b/saml/handler/acs.go @@ -7,7 +7,14 @@ import ( "github.com/hashicorp/cap/saml" ) -func ACSHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +// ACSHandlerFunc creates a handler function that handles a SAML +// ACS request +func ACSHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.ACSHandler" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, r *http.Request) { r.ParseForm() samlResp := r.PostForm.Get("SAMLResponse") @@ -20,5 +27,5 @@ func ACSHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { } fmt.Fprintf(w, "Authenticated! %+v", res) - } + }, nil } diff --git a/saml/handler/metadata.go b/saml/handler/metadata.go index 58c6f5d..324b28b 100644 --- a/saml/handler/metadata.go +++ b/saml/handler/metadata.go @@ -2,17 +2,25 @@ package handler import ( "encoding/xml" + "fmt" "net/http" "github.com/hashicorp/cap/saml" ) -func MetadataHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +// MetadataHandlerFunc creates a handler function that handles a SAML +// metadata request +func MetadataHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.MetadataHandlerFunc" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, _ *http.Request) { meta := sp.CreateMetadata() err := xml.NewEncoder(w).Encode(meta) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } - } + }, nil } diff --git a/saml/handler/post_binding.go b/saml/handler/post_binding.go index 699099a..540b110 100644 --- a/saml/handler/post_binding.go +++ b/saml/handler/post_binding.go @@ -9,7 +9,12 @@ import ( ) // PostBindingHandlerFunc creates a handler function that handles a HTTP-POST binding SAML request. -func PostBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +func PostBindingHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.PostBindingHandlerFunc" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, _ *http.Request) { templ, _, err := sp.AuthnRequestPost("") if err != nil { @@ -33,5 +38,5 @@ func PostBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { ) return } - } + }, nil } diff --git a/saml/handler/redirect_binding.go b/saml/handler/redirect_binding.go index dee5056..9b95ed1 100644 --- a/saml/handler/redirect_binding.go +++ b/saml/handler/redirect_binding.go @@ -7,7 +7,14 @@ import ( "github.com/hashicorp/cap/saml" ) -func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +// RedirectBindingHandlerFunc creates a handler function that handles a SAML +// redirect request. +func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.RedirectBindingHandlerFunc" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, r *http.Request) { redirectURL, _, err := sp.AuthnRequestRedirect("relayState") if err != nil { @@ -24,5 +31,5 @@ func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { fmt.Printf("Redirect URL: %s\n", redirect) http.Redirect(w, r, redirect, http.StatusFound) - } + }, nil } From 154b738462d21ed8db8c06ddea4b3a2ab9e7031f Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 9 Sep 2023 08:43:43 -0400 Subject: [PATCH 5/6] tests (saml): minor code improvements --- saml/authn_request.go | 2 +- saml/config.go | 8 +++--- saml/config_test.go | 27 ++++++++++--------- saml/internal/test/context.go | 1 - saml/models/core/response_test.go | 16 +++++------ saml/models/metadata/entity_descriptor.go | 2 -- .../metadata/idp_sso_descriptor_test.go | 6 +++++ .../models/metadata/sp_sso_descriptor_test.go | 25 +++++------------ saml/response.go | 2 +- saml/response_test.go | 5 ++-- saml/sp.go | 6 ++++- saml/sp_test.go | 12 ++++++--- 12 files changed, 56 insertions(+), 56 deletions(-) delete mode 100644 saml/internal/test/context.go diff --git a/saml/authn_request.go b/saml/authn_request.go index 3668387..5453dbd 100644 --- a/saml/authn_request.go +++ b/saml/authn_request.go @@ -249,7 +249,7 @@ func (sp *ServiceProvider) AuthnRequestPost( b64Payload := base64.StdEncoding.EncodeToString(payload) tmpl := template.Must( - template.New("post-binding").Parse(PostBindingTempl), + template.New("post-binding").Parse(postBindingTempl), ) buf := bytes.Buffer{} diff --git a/saml/config.go b/saml/config.go index 5594afa..689e504 100644 --- a/saml/config.go +++ b/saml/config.go @@ -232,7 +232,7 @@ type configOptions struct { func configOptionsDefault() configOptions { return configOptions{ - withValidUntil: DefaultValidUntil, + withValidUntil: defaultValidUntil, } } @@ -245,7 +245,7 @@ func getConfigOptions(opt ...Option) configOptions { opts.withGenerateAuthRequestID = DefaultGenerateAuthRequestID } if opts.withValidUntil == nil { - opts.withValidUntil = DefaultValidUntil + opts.withValidUntil = defaultValidUntil } return opts @@ -264,8 +264,8 @@ func DefaultGenerateAuthRequestID() (string, error) { return fmt.Sprintf("_%s", newID), nil } -// DefaultValidUntil returns a timestamp with one year +// defaultValidUntil returns a timestamp with one year // added to the time when this function is called. -func DefaultValidUntil() time.Time { +func defaultValidUntil() time.Time { return time.Now().Add(time.Hour * 24 * 365) } diff --git a/saml/config_test.go b/saml/config_test.go index bdaaa15..5c087f7 100644 --- a/saml/config_test.go +++ b/saml/config_test.go @@ -11,10 +11,12 @@ import ( ) func Test_NewConfig(t *testing.T) { - entityID := "http://test.me/entity" - acs := "http://test.me/sso/acs" - metadata := "http://test.me/sso/metadata" - + t.Parallel() + const ( + entityID = "http://test.me/entity" + acs = "http://test.me/sso/acs" + metadata = "http://test.me/sso/metadata" + ) cases := []struct { name string entityID string @@ -62,21 +64,22 @@ func Test_NewConfig(t *testing.T) { if c.expectedErr != "" { r.ErrorContains(err, c.expectedErr) - } else { - r.NoError(err) + return + } + r.NoError(err) - r.Equal(got.EntityID, "http://test.me/entity") - r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs") - r.Equal(got.MetadataURL, "http://test.me/sso/metadata") + r.Equal(got.EntityID, "http://test.me/entity") + r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs") + r.Equal(got.MetadataURL, "http://test.me/sso/metadata") - r.NotNil(got.GenerateAuthRequestID) - r.NotNil(got.ValidUntil) - } + r.NotNil(got.GenerateAuthRequestID) + r.NotNil(got.ValidUntil) }) } } func Test_GenerateAuthRequestID(t *testing.T) { + t.Parallel() r := require.New(t) id, err := saml.DefaultGenerateAuthRequestID() diff --git a/saml/internal/test/context.go b/saml/internal/test/context.go deleted file mode 100644 index a03e1d1..0000000 --- a/saml/internal/test/context.go +++ /dev/null @@ -1 +0,0 @@ -package context diff --git a/saml/models/core/response_test.go b/saml/models/core/response_test.go index 3651607..80adc3f 100644 --- a/saml/models/core/response_test.go +++ b/saml/models/core/response_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -var ResponseXMLSignature = ` +var responseXMLSignature = ` @@ -39,6 +39,7 @@ var responseXMLContainer = ` ` func Test_ParseResponse_ResponseContainer(t *testing.T) { + t.Parallel() r := require.New(t) res := responseXML(t, responseXMLContainer) @@ -56,6 +57,7 @@ var responseXMLIssuer = ` ` func Test_ParseResponse_Issuer(t *testing.T) { + t.Parallel() r := require.New(t) iss := responseXML(t, responseXMLIssuer).Issuer @@ -70,14 +72,6 @@ var responseXMLStatus = ` ` -// func Test_ParseResponse_Status(t *testing.T) { -// r := require.New(t) - -// status := responseXML(t, responseXMLStatus).Status - -// r.Equal(status.StatusCode.Value, core.StatusCodeSuccess) -// } - var responseXMLAssertion = ` @@ -85,6 +79,7 @@ var responseXMLAssertion = ` ` func Test_ParseResponse_Assertion(t *testing.T) { + t.Parallel() r := require.New(t) assert := responseXML(t, responseXMLAssertion).Assertions[0] @@ -103,6 +98,7 @@ var responseXMLAssertionIssuer = ` ` func Test_ParseResponse_Assertion_Issuer(t *testing.T) { + t.Parallel() r := require.New(t) iss := responseXML(t, responseXMLAssertionIssuer).Assertions[0].Issuer @@ -123,6 +119,7 @@ var responseXMLAssertionSubject = ` ` func Test_ParseResponse_Assertion_Subject(t *testing.T) { + t.Parallel() r := require.New(t) sub := responseXML(t, responseXMLAssertionSubject).Assertions[0].Subject @@ -187,6 +184,7 @@ var responseXMLAssertions = ` ` func responseXML(t *testing.T, ssoRes string) core.Response { + t.Parallel() t.Helper() r := require.New(t) diff --git a/saml/models/metadata/entity_descriptor.go b/saml/models/metadata/entity_descriptor.go index 0efb1ad..0e8bb64 100644 --- a/saml/models/metadata/entity_descriptor.go +++ b/saml/models/metadata/entity_descriptor.go @@ -119,8 +119,6 @@ type KeyDescriptor struct { type KeyInfo struct { dsig.KeyInfo KeyName string - // XMLName xml.Name `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"` - // X509Data X509Data `xml:"X509Data"` } // EncyrptionMethod describes the encryption algorithm applied to the cipher data. diff --git a/saml/models/metadata/idp_sso_descriptor_test.go b/saml/models/metadata/idp_sso_descriptor_test.go index 61e45b3..1dc2018 100644 --- a/saml/models/metadata/idp_sso_descriptor_test.go +++ b/saml/models/metadata/idp_sso_descriptor_test.go @@ -47,6 +47,7 @@ var exampleIDPSSODescriptor = ` ` func Test_IDPSSODescriptor(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorIDPSSO{} @@ -74,6 +75,7 @@ var exampleIDPSSOKeyDescriptor = ` ` func Test_IDPSSODescriptor_KeyDescriptor(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorIDPSSO{} @@ -98,6 +100,7 @@ var exampleIDPSSODescriptorArtifactResolutionService = `` func Test_IDPSSODescriptor_ArtifactResolutionService(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorIDPSSO{} @@ -126,6 +129,7 @@ var exampleIDPSSODescriptorSLO = ` ` func Test_IDPSSODescriptor_SLO(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorIDPSSO{} @@ -155,6 +159,7 @@ var exampleIDPSSODescriptorSSO = ` ` func Test_IDPSSODescriptor_SSO(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorIDPSSO{} @@ -190,6 +195,7 @@ var exampleIDPSSODescriptorAttributes = ` ` func Test_IDPSSODescriptor_Attributes(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorIDPSSO{} diff --git a/saml/models/metadata/sp_sso_descriptor_test.go b/saml/models/metadata/sp_sso_descriptor_test.go index 3c372fd..b1dc0b6 100644 --- a/saml/models/metadata/sp_sso_descriptor_test.go +++ b/saml/models/metadata/sp_sso_descriptor_test.go @@ -56,6 +56,7 @@ var exampleSPSSODescriptor = `` func Test_SPSSODescriptor(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorSPSSO{} @@ -89,6 +90,7 @@ var exampleSLOService = `` func Test_SPSSODescriptor_SLOService(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorSPSSO{} @@ -129,6 +131,7 @@ var exampleNameIDService = `` func Test_SPSSODescriptor_ManageNameIDService(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorSPSSO{} @@ -160,6 +163,7 @@ var exampleNameIDFormats = `` func Test_SPSSODescriptor_NameIDFormats(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorSPSSO{} @@ -193,6 +197,7 @@ var exampleACS = `` func Test_SPSSODescriptor_ACS(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorSPSSO{} @@ -250,6 +255,7 @@ var exampleAttributeConsumingService = `By-Tor func Test_SPSSODescriptor_AttributeConsumingService(t *testing.T) { + t.Parallel() r := require.New(t) ed := &metadata.EntityDescriptorSPSSO{} @@ -332,22 +338,3 @@ x5Ql0ejivIJAYcMGUyA+/YwJg2FGoA== ` - -// func Test_SPSSODescriptor_KeyDescritpor(t *testing.T) { -// r := require.New(t) - -// ed := &metadata.EntityDescriptor{} - -// err := xml.Unmarshal([]byte(exampleKeyDescriptor), ed) -// r.NoError(err) - -// keyDescriptor := ed.SPSSODescriptor[0].KeyDescriptor - -// r.Len(keyDescriptor, 2) - -// r.Equal(keyDescriptor[0].Use, metadata.KeyTypeSigning) -// r.NotEmpty(keyDescriptor[0].KeyInfo.X509Data, "") - -// r.Equal(keyDescriptor[1].Use, metadata.KeyTypeEncryption) -// r.NotEmpty(keyDescriptor[1].KeyInfo.X509Data, "") -// } diff --git a/saml/response.go b/saml/response.go index 8eaf5a1..1af9a47 100644 --- a/saml/response.go +++ b/saml/response.go @@ -178,7 +178,7 @@ func (sp *ServiceProvider) internalParser( for _, xcert := range kd.KeyInfo.X509Data.X509Certificates { parsed, err := parseX509Certificate(xcert.Data) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: unable to parse cert: %w", op, err) } certStore.Roots = append(certStore.Roots, parsed) // append works just fine with a nil slice } diff --git a/saml/response_test.go b/saml/response_test.go index 9997389..9685793 100644 --- a/saml/response_test.go +++ b/saml/response_test.go @@ -91,6 +91,7 @@ func TestServiceProvider_ParseResponse(t *testing.T) { } func TestServiceProvider_ParseResponseCustomACS(t *testing.T) { + t.Parallel() r := require.New(t) fakeTime, err := time.Parse("2006-01-02", "2015-07-15") @@ -151,9 +152,9 @@ func TestServiceProvider_ParseResponseCustomACS(t *testing.T) { ) if c.err == "" { require.NoError(t, err) - } else { - require.ErrorContains(t, err, c.err) + return } + require.ErrorContains(t, err, c.err) }) } diff --git a/saml/sp.go b/saml/sp.go index 43c0aa5..4b38b23 100644 --- a/saml/sp.go +++ b/saml/sp.go @@ -15,7 +15,7 @@ import ( ) //go:embed authn_request.gohtml -var PostBindingTempl string +var postBindingTempl string type metadataOptions struct { wantAssertionsSigned bool @@ -37,6 +37,8 @@ func getMetadataOptions(opt ...Option) metadataOptions { return opts } +// InsecureWantAssertionsUnsigned provides a way to optionally request that you +// want insecure/unsigned assertions. func InsecureWantAssertionsUnsigned() Option { return func(o interface{}) { if o, ok := o.(*metadataOptions); ok { @@ -55,6 +57,7 @@ func WithMetadataNameIDFormat(format ...core.NameIDFormat) Option { } } +// WithACSServiceBinding provides an optional service binding. func WithACSServiceBinding(b core.ServiceBinding) Option { return func(o interface{}) { if o, ok := o.(*metadataOptions); ok { @@ -75,6 +78,7 @@ func WithAdditionalACSEndpoint(b core.ServiceBinding, location url.URL) Option { } } +// ServiceProvider defines a type for service providers type ServiceProvider struct { cfg *Config } diff --git a/saml/sp_test.go b/saml/sp_test.go index 57e7452..f27939b 100644 --- a/saml/sp_test.go +++ b/saml/sp_test.go @@ -16,6 +16,7 @@ import ( ) func Test_NewServiceProvider(t *testing.T) { + t.Parallel() r := require.New(t) exampleURL := "http://test.me" @@ -56,16 +57,17 @@ func Test_NewServiceProvider(t *testing.T) { if c.err != "" { r.Error(err) r.ErrorContains(err, c.err) - } else { - r.NoError(err) - r.NotNil(got) - r.NotNil(got.Config()) + return } + r.NoError(err) + r.NotNil(got) + r.NotNil(got.Config()) }) } } func Test_ServiceProvider_FetchMetadata_ErrorCases(t *testing.T) { + t.Parallel() r := require.New(t) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -117,6 +119,7 @@ func Test_ServiceProvider_FetchMetadata_ErrorCases(t *testing.T) { } func Test_ServiceProvider_CreateMetadata(t *testing.T) { + t.Parallel() r := require.New(t) entityID := "http://test.me/entity" @@ -185,6 +188,7 @@ func Test_ServiceProvider_CreateMetadata(t *testing.T) { } func Test_CreateMetadata_Options(t *testing.T) { + t.Parallel() r := require.New(t) fakeURL := "http://fake.test.url" From 8c84e747fb9592d3057f8a8b10844463606db631 Mon Sep 17 00:00:00 2001 From: Jim Date: Sat, 9 Sep 2023 12:54:28 -0400 Subject: [PATCH 6/6] fixup! tests (saml): minor code improvements --- saml/authn_request_test.go | 43 +++++++++++++++++-------------- saml/models/core/response_test.go | 2 -- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/saml/authn_request_test.go b/saml/authn_request_test.go index 46e67ac..0d961d2 100644 --- a/saml/authn_request_test.go +++ b/saml/authn_request_test.go @@ -15,6 +15,7 @@ import ( ) func Test_CreateAuthnRequest(t *testing.T) { + t.Parallel() r := require.New(t) tp := testprovider.StartTestProvider(t) @@ -87,32 +88,33 @@ func Test_CreateAuthnRequest(t *testing.T) { if c.err != "" { r.Error(err) r.ErrorContains(err, c.err) - } else { - r.NoError(err) - - switch c.binding { - case core.ServiceBindingHTTPPost: - loc := fmt.Sprintf("%s/saml/login/post", tp.ServerURL()) - r.Equal(loc, got.Destination) - case core.ServiceBindingHTTPRedirect: - loc := fmt.Sprintf("%s/saml/login/redirect", tp.ServerURL()) - r.Equal(loc, got.Destination) - } - - r.Equal(c.id, got.ID) - r.Equal("2.0", got.Version) - r.Equal(core.ServiceBindingHTTPPost, got.ProtocolBinding) - r.Equal(c.expectedACS, got.AssertionConsumerServiceURL) - r.Equal("http://test.me/entity", got.Issuer.Value) - r.Nil(got.NameIDPolicy) - r.Nil(got.RequestedAuthContext) - r.False(got.ForceAuthn) + return } + r.NoError(err) + + switch c.binding { + case core.ServiceBindingHTTPPost: + loc := fmt.Sprintf("%s/saml/login/post", tp.ServerURL()) + r.Equal(loc, got.Destination) + case core.ServiceBindingHTTPRedirect: + loc := fmt.Sprintf("%s/saml/login/redirect", tp.ServerURL()) + r.Equal(loc, got.Destination) + } + + r.Equal(c.id, got.ID) + r.Equal("2.0", got.Version) + r.Equal(core.ServiceBindingHTTPPost, got.ProtocolBinding) + r.Equal(c.expectedACS, got.AssertionConsumerServiceURL) + r.Equal("http://test.me/entity", got.Issuer.Value) + r.Nil(got.NameIDPolicy) + r.Nil(got.RequestedAuthContext) + r.False(got.ForceAuthn) }) } } func Test_CreateAuthnRequest_Options(t *testing.T) { + t.Parallel() r := require.New(t) tp := testprovider.StartTestProvider(t) @@ -216,6 +218,7 @@ func Test_CreateAuthnRequest_Options(t *testing.T) { } func Test_ServiceProvider_AuthnRequestRedirect(t *testing.T) { + t.Parallel() r := require.New(t) tp := testprovider.StartTestProvider(t) diff --git a/saml/models/core/response_test.go b/saml/models/core/response_test.go index 80adc3f..f1dce5c 100644 --- a/saml/models/core/response_test.go +++ b/saml/models/core/response_test.go @@ -46,7 +46,6 @@ func Test_ParseResponse_ResponseContainer(t *testing.T) { r.Equal(res.Destination, "http://localhost:8000/saml/acs") r.Equal(res.ID, "saml-response-id") - // r.Equal(res.InResponseTo, "saml-request-id") r.Equal(res.IssueInstant.String(), "2023-03-31 06:55:44.494 +0000 UTC") r.Equal(res.Version, "2.0") } @@ -184,7 +183,6 @@ var responseXMLAssertions = ` ` func responseXML(t *testing.T, ssoRes string) core.Response { - t.Parallel() t.Helper() r := require.New(t)