diff --git a/saml/authn_request.go b/saml/authn_request.go index 3668387..5453dbd 100644 --- a/saml/authn_request.go +++ b/saml/authn_request.go @@ -249,7 +249,7 @@ func (sp *ServiceProvider) AuthnRequestPost( b64Payload := base64.StdEncoding.EncodeToString(payload) tmpl := template.Must( - template.New("post-binding").Parse(PostBindingTempl), + template.New("post-binding").Parse(postBindingTempl), ) buf := bytes.Buffer{} diff --git a/saml/authn_request_test.go b/saml/authn_request_test.go index 46e67ac..0d961d2 100644 --- a/saml/authn_request_test.go +++ b/saml/authn_request_test.go @@ -15,6 +15,7 @@ import ( ) func Test_CreateAuthnRequest(t *testing.T) { + t.Parallel() r := require.New(t) tp := testprovider.StartTestProvider(t) @@ -87,32 +88,33 @@ func Test_CreateAuthnRequest(t *testing.T) { if c.err != "" { r.Error(err) r.ErrorContains(err, c.err) - } else { - r.NoError(err) - - switch c.binding { - case core.ServiceBindingHTTPPost: - loc := fmt.Sprintf("%s/saml/login/post", tp.ServerURL()) - r.Equal(loc, got.Destination) - case core.ServiceBindingHTTPRedirect: - loc := fmt.Sprintf("%s/saml/login/redirect", tp.ServerURL()) - r.Equal(loc, got.Destination) - } - - r.Equal(c.id, got.ID) - r.Equal("2.0", got.Version) - r.Equal(core.ServiceBindingHTTPPost, got.ProtocolBinding) - r.Equal(c.expectedACS, got.AssertionConsumerServiceURL) - r.Equal("http://test.me/entity", got.Issuer.Value) - r.Nil(got.NameIDPolicy) - r.Nil(got.RequestedAuthContext) - r.False(got.ForceAuthn) + return } + r.NoError(err) + + switch c.binding { + case core.ServiceBindingHTTPPost: + loc := fmt.Sprintf("%s/saml/login/post", tp.ServerURL()) + r.Equal(loc, got.Destination) + case core.ServiceBindingHTTPRedirect: + loc := fmt.Sprintf("%s/saml/login/redirect", tp.ServerURL()) + r.Equal(loc, got.Destination) + } + + r.Equal(c.id, got.ID) + r.Equal("2.0", got.Version) + r.Equal(core.ServiceBindingHTTPPost, got.ProtocolBinding) + r.Equal(c.expectedACS, got.AssertionConsumerServiceURL) + r.Equal("http://test.me/entity", got.Issuer.Value) + r.Nil(got.NameIDPolicy) + r.Nil(got.RequestedAuthContext) + r.False(got.ForceAuthn) }) } } func Test_CreateAuthnRequest_Options(t *testing.T) { + t.Parallel() r := require.New(t) tp := testprovider.StartTestProvider(t) @@ -216,6 +218,7 @@ func Test_CreateAuthnRequest_Options(t *testing.T) { } func Test_ServiceProvider_AuthnRequestRedirect(t *testing.T) { + t.Parallel() r := require.New(t) tp := testprovider.StartTestProvider(t) diff --git a/saml/config.go b/saml/config.go index 5594afa..689e504 100644 --- a/saml/config.go +++ b/saml/config.go @@ -232,7 +232,7 @@ type configOptions struct { func configOptionsDefault() configOptions { return configOptions{ - withValidUntil: DefaultValidUntil, + withValidUntil: defaultValidUntil, } } @@ -245,7 +245,7 @@ func getConfigOptions(opt ...Option) configOptions { opts.withGenerateAuthRequestID = DefaultGenerateAuthRequestID } if opts.withValidUntil == nil { - opts.withValidUntil = DefaultValidUntil + opts.withValidUntil = defaultValidUntil } return opts @@ -264,8 +264,8 @@ func DefaultGenerateAuthRequestID() (string, error) { return fmt.Sprintf("_%s", newID), nil } -// DefaultValidUntil returns a timestamp with one year +// defaultValidUntil returns a timestamp with one year // added to the time when this function is called. -func DefaultValidUntil() time.Time { +func defaultValidUntil() time.Time { return time.Now().Add(time.Hour * 24 * 365) } diff --git a/saml/config_test.go b/saml/config_test.go index bdaaa15..5c087f7 100644 --- a/saml/config_test.go +++ b/saml/config_test.go @@ -11,10 +11,12 @@ import ( ) func Test_NewConfig(t *testing.T) { - entityID := "http://test.me/entity" - acs := "http://test.me/sso/acs" - metadata := "http://test.me/sso/metadata" - + t.Parallel() + const ( + entityID = "http://test.me/entity" + acs = "http://test.me/sso/acs" + metadata = "http://test.me/sso/metadata" + ) cases := []struct { name string entityID string @@ -62,21 +64,22 @@ func Test_NewConfig(t *testing.T) { if c.expectedErr != "" { r.ErrorContains(err, c.expectedErr) - } else { - r.NoError(err) + return + } + r.NoError(err) - r.Equal(got.EntityID, "http://test.me/entity") - r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs") - r.Equal(got.MetadataURL, "http://test.me/sso/metadata") + r.Equal(got.EntityID, "http://test.me/entity") + r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs") + r.Equal(got.MetadataURL, "http://test.me/sso/metadata") - r.NotNil(got.GenerateAuthRequestID) - r.NotNil(got.ValidUntil) - } + r.NotNil(got.GenerateAuthRequestID) + r.NotNil(got.ValidUntil) }) } } func Test_GenerateAuthRequestID(t *testing.T) { + t.Parallel() r := require.New(t) id, err := saml.DefaultGenerateAuthRequestID() diff --git a/saml/demo/main.go b/saml/demo/main.go index 55f301e..99f99fb 100644 --- a/saml/demo/main.go +++ b/saml/demo/main.go @@ -29,10 +29,22 @@ func main() { sp, err := saml.NewServiceProvider(cfg) exitOnError(err) - http.HandleFunc("/saml/acs", handler.ACSHandlerFunc(sp)) - http.HandleFunc("/saml/auth/redirect", handler.RedirectBindingHandlerFunc(sp)) - http.HandleFunc("/saml/auth/post", handler.PostBindingHandlerFunc(sp)) - http.HandleFunc("/metadata", handler.MetadataHandlerFunc(sp)) + acsHandler, err := handler.ACSHandlerFunc(sp) + exitOnError(err) + + redirectHandler, err := handler.RedirectBindingHandlerFunc(sp) + exitOnError(err) + + postBindHandler, err := handler.PostBindingHandlerFunc(sp) + exitOnError(err) + + metadataHandler, err := handler.MetadataHandlerFunc(sp) + exitOnError(err) + + http.HandleFunc("/saml/acs", acsHandler) + http.HandleFunc("/saml/auth/redirect", redirectHandler) + http.HandleFunc("/saml/auth/post", postBindHandler) + http.HandleFunc("/metadata", metadataHandler) http.HandleFunc("/login", func(w http.ResponseWriter, _ *http.Request) { ts, _ := template.New("sso").Parse( `
diff --git a/saml/handler/acs.go b/saml/handler/acs.go index 3f12e34..d0693e2 100644 --- a/saml/handler/acs.go +++ b/saml/handler/acs.go @@ -7,7 +7,14 @@ import ( "github.com/hashicorp/cap/saml" ) -func ACSHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +// ACSHandlerFunc creates a handler function that handles a SAML +// ACS request +func ACSHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.ACSHandler" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, r *http.Request) { r.ParseForm() samlResp := r.PostForm.Get("SAMLResponse") @@ -20,5 +27,5 @@ func ACSHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { } fmt.Fprintf(w, "Authenticated! %+v", res) - } + }, nil } diff --git a/saml/handler/metadata.go b/saml/handler/metadata.go index 58c6f5d..324b28b 100644 --- a/saml/handler/metadata.go +++ b/saml/handler/metadata.go @@ -2,17 +2,25 @@ package handler import ( "encoding/xml" + "fmt" "net/http" "github.com/hashicorp/cap/saml" ) -func MetadataHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +// MetadataHandlerFunc creates a handler function that handles a SAML +// metadata request +func MetadataHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.MetadataHandlerFunc" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, _ *http.Request) { meta := sp.CreateMetadata() err := xml.NewEncoder(w).Encode(meta) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } - } + }, nil } diff --git a/saml/handler/post_binding.go b/saml/handler/post_binding.go index 699099a..540b110 100644 --- a/saml/handler/post_binding.go +++ b/saml/handler/post_binding.go @@ -9,7 +9,12 @@ import ( ) // PostBindingHandlerFunc creates a handler function that handles a HTTP-POST binding SAML request. -func PostBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +func PostBindingHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.PostBindingHandlerFunc" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, _ *http.Request) { templ, _, err := sp.AuthnRequestPost("") if err != nil { @@ -33,5 +38,5 @@ func PostBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { ) return } - } + }, nil } diff --git a/saml/handler/redirect_binding.go b/saml/handler/redirect_binding.go index dee5056..9b95ed1 100644 --- a/saml/handler/redirect_binding.go +++ b/saml/handler/redirect_binding.go @@ -7,7 +7,14 @@ import ( "github.com/hashicorp/cap/saml" ) -func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { +// RedirectBindingHandlerFunc creates a handler function that handles a SAML +// redirect request. +func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) { + const op = "handler.RedirectBindingHandlerFunc" + switch { + case sp == nil: + return nil, fmt.Errorf("%s: missing service provider", op) + } return func(w http.ResponseWriter, r *http.Request) { redirectURL, _, err := sp.AuthnRequestRedirect("relayState") if err != nil { @@ -24,5 +31,5 @@ func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc { fmt.Printf("Redirect URL: %s\n", redirect) http.Redirect(w, r, redirect, http.StatusFound) - } + }, nil } diff --git a/saml/internal/test/context.go b/saml/internal/test/context.go deleted file mode 100644 index a03e1d1..0000000 --- a/saml/internal/test/context.go +++ /dev/null @@ -1 +0,0 @@ -package context diff --git a/saml/is_nil.go b/saml/is_nil.go new file mode 100644 index 0000000..c54654e --- /dev/null +++ b/saml/is_nil.go @@ -0,0 +1,15 @@ +package saml + +import "reflect" + +// isNil reports if a is nil +func isNil(a any) bool { + if a == nil { + return true + } + switch reflect.TypeOf(a).Kind() { + case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func: + return reflect.ValueOf(a).IsNil() + } + return false +} diff --git a/saml/is_nil_test.go b/saml/is_nil_test.go new file mode 100644 index 0000000..ff9fd40 --- /dev/null +++ b/saml/is_nil_test.go @@ -0,0 +1,56 @@ +package saml + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_isNil(t *testing.T) { + t.Parallel() + + var testErrNilPtr *testError + var testMapNilPtr map[string]struct{} + var testArrayNilPtr *[1]string + var testChanNilPtr *chan string + var testSliceNilPtr *[]string + var testFuncNil func() + + var testChanString chan string + + tc := []struct { + i any + want bool + }{ + {i: &testError{}, want: false}, + {i: testError{}, want: false}, + {i: &map[string]struct{}{}, want: false}, + {i: map[string]struct{}{}, want: false}, + {i: [1]string{}, want: false}, + {i: &[1]string{}, want: false}, + {i: &testChanString, want: false}, + {i: "string", want: false}, + {i: []string{}, want: false}, + {i: func() {}, want: false}, + {i: nil, want: true}, + {i: testErrNilPtr, want: true}, + {i: testMapNilPtr, want: true}, + {i: testArrayNilPtr, want: true}, + {i: testChanNilPtr, want: true}, + {i: testChanString, want: true}, + {i: testSliceNilPtr, want: true}, + {i: testFuncNil, want: true}, + } + + for i, tc := range tc { + t.Run(fmt.Sprintf("test #%d", i+1), func(t *testing.T) { + assert := assert.New(t) + assert.Equal(tc.want, isNil(tc.i)) + }) + } +} + +type testError struct{} + +func (*testError) Error() string { return "error" } diff --git a/saml/models/core/response_test.go b/saml/models/core/response_test.go index 3651607..f1dce5c 100644 --- a/saml/models/core/response_test.go +++ b/saml/models/core/response_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -var ResponseXMLSignature = ` +var responseXMLSignature = `