Skip to content

Commit

Permalink
saml: minor code improvements (#101)
Browse files Browse the repository at this point in the history
* fix (saml): address possible panic if clock.Clock is nil

* fix (saml): fix possible panic in WithAdditionalACSEndpoint(...)

changed  location url to be passed by value to eliminate possible
panic

* refactor (saml): add WithMetadataNameIDFormat(...)

Refactor WithAdditionalNameIDFormat(...) and WithNameIDFormats(...)
into one new option WithMetadataNameIDFormat(...)

* fix (saml): address possible panics in saml handlers

* tests (saml): minor code improvements
  • Loading branch information
jimlambrt authored Sep 9, 2023
1 parent d255ea8 commit 3a603e1
Show file tree
Hide file tree
Showing 20 changed files with 214 additions and 107 deletions.
2 changes: 1 addition & 1 deletion saml/authn_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (sp *ServiceProvider) AuthnRequestPost(
b64Payload := base64.StdEncoding.EncodeToString(payload)

tmpl := template.Must(
template.New("post-binding").Parse(PostBindingTempl),
template.New("post-binding").Parse(postBindingTempl),
)

buf := bytes.Buffer{}
Expand Down
43 changes: 23 additions & 20 deletions saml/authn_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
)

func Test_CreateAuthnRequest(t *testing.T) {
t.Parallel()
r := require.New(t)

tp := testprovider.StartTestProvider(t)
Expand Down Expand Up @@ -87,32 +88,33 @@ func Test_CreateAuthnRequest(t *testing.T) {
if c.err != "" {
r.Error(err)
r.ErrorContains(err, c.err)
} else {
r.NoError(err)

switch c.binding {
case core.ServiceBindingHTTPPost:
loc := fmt.Sprintf("%s/saml/login/post", tp.ServerURL())
r.Equal(loc, got.Destination)
case core.ServiceBindingHTTPRedirect:
loc := fmt.Sprintf("%s/saml/login/redirect", tp.ServerURL())
r.Equal(loc, got.Destination)
}

r.Equal(c.id, got.ID)
r.Equal("2.0", got.Version)
r.Equal(core.ServiceBindingHTTPPost, got.ProtocolBinding)
r.Equal(c.expectedACS, got.AssertionConsumerServiceURL)
r.Equal("http://test.me/entity", got.Issuer.Value)
r.Nil(got.NameIDPolicy)
r.Nil(got.RequestedAuthContext)
r.False(got.ForceAuthn)
return
}
r.NoError(err)

switch c.binding {
case core.ServiceBindingHTTPPost:
loc := fmt.Sprintf("%s/saml/login/post", tp.ServerURL())
r.Equal(loc, got.Destination)
case core.ServiceBindingHTTPRedirect:
loc := fmt.Sprintf("%s/saml/login/redirect", tp.ServerURL())
r.Equal(loc, got.Destination)
}

r.Equal(c.id, got.ID)
r.Equal("2.0", got.Version)
r.Equal(core.ServiceBindingHTTPPost, got.ProtocolBinding)
r.Equal(c.expectedACS, got.AssertionConsumerServiceURL)
r.Equal("http://test.me/entity", got.Issuer.Value)
r.Nil(got.NameIDPolicy)
r.Nil(got.RequestedAuthContext)
r.False(got.ForceAuthn)
})
}
}

func Test_CreateAuthnRequest_Options(t *testing.T) {
t.Parallel()
r := require.New(t)

tp := testprovider.StartTestProvider(t)
Expand Down Expand Up @@ -216,6 +218,7 @@ func Test_CreateAuthnRequest_Options(t *testing.T) {
}

func Test_ServiceProvider_AuthnRequestRedirect(t *testing.T) {
t.Parallel()
r := require.New(t)

tp := testprovider.StartTestProvider(t)
Expand Down
8 changes: 4 additions & 4 deletions saml/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ type configOptions struct {

func configOptionsDefault() configOptions {
return configOptions{
withValidUntil: DefaultValidUntil,
withValidUntil: defaultValidUntil,
}
}

Expand All @@ -245,7 +245,7 @@ func getConfigOptions(opt ...Option) configOptions {
opts.withGenerateAuthRequestID = DefaultGenerateAuthRequestID
}
if opts.withValidUntil == nil {
opts.withValidUntil = DefaultValidUntil
opts.withValidUntil = defaultValidUntil
}

return opts
Expand All @@ -264,8 +264,8 @@ func DefaultGenerateAuthRequestID() (string, error) {
return fmt.Sprintf("_%s", newID), nil
}

// DefaultValidUntil returns a timestamp with one year
// defaultValidUntil returns a timestamp with one year
// added to the time when this function is called.
func DefaultValidUntil() time.Time {
func defaultValidUntil() time.Time {
return time.Now().Add(time.Hour * 24 * 365)
}
27 changes: 15 additions & 12 deletions saml/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
)

func Test_NewConfig(t *testing.T) {
entityID := "http://test.me/entity"
acs := "http://test.me/sso/acs"
metadata := "http://test.me/sso/metadata"

t.Parallel()
const (
entityID = "http://test.me/entity"
acs = "http://test.me/sso/acs"
metadata = "http://test.me/sso/metadata"
)
cases := []struct {
name string
entityID string
Expand Down Expand Up @@ -62,21 +64,22 @@ func Test_NewConfig(t *testing.T) {

if c.expectedErr != "" {
r.ErrorContains(err, c.expectedErr)
} else {
r.NoError(err)
return
}
r.NoError(err)

r.Equal(got.EntityID, "http://test.me/entity")
r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs")
r.Equal(got.MetadataURL, "http://test.me/sso/metadata")
r.Equal(got.EntityID, "http://test.me/entity")
r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs")
r.Equal(got.MetadataURL, "http://test.me/sso/metadata")

r.NotNil(got.GenerateAuthRequestID)
r.NotNil(got.ValidUntil)
}
r.NotNil(got.GenerateAuthRequestID)
r.NotNil(got.ValidUntil)
})
}
}

func Test_GenerateAuthRequestID(t *testing.T) {
t.Parallel()
r := require.New(t)

id, err := saml.DefaultGenerateAuthRequestID()
Expand Down
20 changes: 16 additions & 4 deletions saml/demo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,22 @@ func main() {
sp, err := saml.NewServiceProvider(cfg)
exitOnError(err)

http.HandleFunc("/saml/acs", handler.ACSHandlerFunc(sp))
http.HandleFunc("/saml/auth/redirect", handler.RedirectBindingHandlerFunc(sp))
http.HandleFunc("/saml/auth/post", handler.PostBindingHandlerFunc(sp))
http.HandleFunc("/metadata", handler.MetadataHandlerFunc(sp))
acsHandler, err := handler.ACSHandlerFunc(sp)
exitOnError(err)

redirectHandler, err := handler.RedirectBindingHandlerFunc(sp)
exitOnError(err)

postBindHandler, err := handler.PostBindingHandlerFunc(sp)
exitOnError(err)

metadataHandler, err := handler.MetadataHandlerFunc(sp)
exitOnError(err)

http.HandleFunc("/saml/acs", acsHandler)
http.HandleFunc("/saml/auth/redirect", redirectHandler)
http.HandleFunc("/saml/auth/post", postBindHandler)
http.HandleFunc("/metadata", metadataHandler)
http.HandleFunc("/login", func(w http.ResponseWriter, _ *http.Request) {
ts, _ := template.New("sso").Parse(
`<html><form method="GET" action="/saml/auth/redirect"><button type="submit">Submit Redirect</button></form></html>
Expand Down
11 changes: 9 additions & 2 deletions saml/handler/acs.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ import (
"github.com/hashicorp/cap/saml"
)

func ACSHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
// ACSHandlerFunc creates a handler function that handles a SAML
// ACS request
func ACSHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) {
const op = "handler.ACSHandler"
switch {
case sp == nil:
return nil, fmt.Errorf("%s: missing service provider", op)
}
return func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
samlResp := r.PostForm.Get("SAMLResponse")
Expand All @@ -20,5 +27,5 @@ func ACSHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
}

fmt.Fprintf(w, "Authenticated! %+v", res)
}
}, nil
}
12 changes: 10 additions & 2 deletions saml/handler/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,25 @@ package handler

import (
"encoding/xml"
"fmt"
"net/http"

"github.com/hashicorp/cap/saml"
)

func MetadataHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
// MetadataHandlerFunc creates a handler function that handles a SAML
// metadata request
func MetadataHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) {
const op = "handler.MetadataHandlerFunc"
switch {
case sp == nil:
return nil, fmt.Errorf("%s: missing service provider", op)
}
return func(w http.ResponseWriter, _ *http.Request) {
meta := sp.CreateMetadata()
err := xml.NewEncoder(w).Encode(meta)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}, nil
}
9 changes: 7 additions & 2 deletions saml/handler/post_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ import (
)

// PostBindingHandlerFunc creates a handler function that handles a HTTP-POST binding SAML request.
func PostBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
func PostBindingHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) {
const op = "handler.PostBindingHandlerFunc"
switch {
case sp == nil:
return nil, fmt.Errorf("%s: missing service provider", op)
}
return func(w http.ResponseWriter, _ *http.Request) {
templ, _, err := sp.AuthnRequestPost("")
if err != nil {
Expand All @@ -33,5 +38,5 @@ func PostBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
)
return
}
}
}, nil
}
11 changes: 9 additions & 2 deletions saml/handler/redirect_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ import (
"github.com/hashicorp/cap/saml"
)

func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
// RedirectBindingHandlerFunc creates a handler function that handles a SAML
// redirect request.
func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) (http.HandlerFunc, error) {
const op = "handler.RedirectBindingHandlerFunc"
switch {
case sp == nil:
return nil, fmt.Errorf("%s: missing service provider", op)
}
return func(w http.ResponseWriter, r *http.Request) {
redirectURL, _, err := sp.AuthnRequestRedirect("relayState")
if err != nil {
Expand All @@ -24,5 +31,5 @@ func RedirectBindingHandlerFunc(sp *saml.ServiceProvider) http.HandlerFunc {
fmt.Printf("Redirect URL: %s\n", redirect)

http.Redirect(w, r, redirect, http.StatusFound)
}
}, nil
}
1 change: 0 additions & 1 deletion saml/internal/test/context.go

This file was deleted.

15 changes: 15 additions & 0 deletions saml/is_nil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package saml

import "reflect"

// isNil reports if a is nil
func isNil(a any) bool {
if a == nil {
return true
}
switch reflect.TypeOf(a).Kind() {
case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func:
return reflect.ValueOf(a).IsNil()
}
return false
}
56 changes: 56 additions & 0 deletions saml/is_nil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package saml

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_isNil(t *testing.T) {
t.Parallel()

var testErrNilPtr *testError
var testMapNilPtr map[string]struct{}
var testArrayNilPtr *[1]string
var testChanNilPtr *chan string
var testSliceNilPtr *[]string
var testFuncNil func()

var testChanString chan string

tc := []struct {
i any
want bool
}{
{i: &testError{}, want: false},
{i: testError{}, want: false},
{i: &map[string]struct{}{}, want: false},
{i: map[string]struct{}{}, want: false},
{i: [1]string{}, want: false},
{i: &[1]string{}, want: false},
{i: &testChanString, want: false},
{i: "string", want: false},
{i: []string{}, want: false},
{i: func() {}, want: false},
{i: nil, want: true},
{i: testErrNilPtr, want: true},
{i: testMapNilPtr, want: true},
{i: testArrayNilPtr, want: true},
{i: testChanNilPtr, want: true},
{i: testChanString, want: true},
{i: testSliceNilPtr, want: true},
{i: testFuncNil, want: true},
}

for i, tc := range tc {
t.Run(fmt.Sprintf("test #%d", i+1), func(t *testing.T) {
assert := assert.New(t)
assert.Equal(tc.want, isNil(tc.i))
})
}
}

type testError struct{}

func (*testError) Error() string { return "error" }
Loading

0 comments on commit 3a603e1

Please sign in to comment.