Skip to content

Commit

Permalink
tests (saml): minor code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Sep 9, 2023
1 parent 3b302b0 commit af43639
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 28 deletions.
2 changes: 1 addition & 1 deletion saml/authn_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (sp *ServiceProvider) AuthnRequestPost(
b64Payload := base64.StdEncoding.EncodeToString(payload)

tmpl := template.Must(
template.New("post-binding").Parse(PostBindingTempl),
template.New("post-binding").Parse(postBindingTempl),
)

buf := bytes.Buffer{}
Expand Down
25 changes: 6 additions & 19 deletions saml/models/metadata/sp_sso_descriptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ var exampleSPSSODescriptor = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -89,6 +90,7 @@ var exampleSLOService = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_SLOService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -129,6 +131,7 @@ var exampleNameIDService = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_ManageNameIDService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -160,6 +163,7 @@ var exampleNameIDFormats = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_NameIDFormats(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -193,6 +197,7 @@ var exampleACS = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_ACS(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -250,6 +255,7 @@ var exampleAttributeConsumingService = `<EntityDescriptor
// <saml:AttributeValue type="xs:string">By-Tor</saml:AttributeValue>

func Test_SPSSODescriptor_AttributeConsumingService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -332,22 +338,3 @@ x5Ql0ejivIJAYcMGUyA+/YwJg2FGoA==
</KeyDescriptor>
</SPSSODescriptor>
</EntityDescriptor>`

// func Test_SPSSODescriptor_KeyDescritpor(t *testing.T) {
// r := require.New(t)

// ed := &metadata.EntityDescriptor{}

// err := xml.Unmarshal([]byte(exampleKeyDescriptor), ed)
// r.NoError(err)

// keyDescriptor := ed.SPSSODescriptor[0].KeyDescriptor

// r.Len(keyDescriptor, 2)

// r.Equal(keyDescriptor[0].Use, metadata.KeyTypeSigning)
// r.NotEmpty(keyDescriptor[0].KeyInfo.X509Data, "")

// r.Equal(keyDescriptor[1].Use, metadata.KeyTypeEncryption)
// r.NotEmpty(keyDescriptor[1].KeyInfo.X509Data, "")
// }
2 changes: 1 addition & 1 deletion saml/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (sp *ServiceProvider) internalParser(
for _, xcert := range kd.KeyInfo.X509Data.X509Certificates {
parsed, err := parseX509Certificate(xcert.Data)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s: unable to parse cert: %w", op, err)
}
certStore.Roots = append(certStore.Roots, parsed) // append works just fine with a nil slice
}
Expand Down
5 changes: 3 additions & 2 deletions saml/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func TestServiceProvider_ParseResponse(t *testing.T) {
}

func TestServiceProvider_ParseResponseCustomACS(t *testing.T) {
t.Parallel()
r := require.New(t)

fakeTime, err := time.Parse("2006-01-02", "2015-07-15")
Expand Down Expand Up @@ -151,9 +152,9 @@ func TestServiceProvider_ParseResponseCustomACS(t *testing.T) {
)
if c.err == "" {
require.NoError(t, err)
} else {
require.ErrorContains(t, err, c.err)
return
}
require.ErrorContains(t, err, c.err)
})
}

Expand Down
6 changes: 5 additions & 1 deletion saml/sp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

//go:embed authn_request.gohtml
var PostBindingTempl string
var postBindingTempl string

type metadataOptions struct {
wantAssertionsSigned bool
Expand All @@ -37,6 +37,8 @@ func getMetadataOptions(opt ...Option) metadataOptions {
return opts
}

// InsecureWantAssertionsUnsigned provides a way to optionally request that you
// want insecure/unsigned assertions.
func InsecureWantAssertionsUnsigned() Option {
return func(o interface{}) {
if o, ok := o.(*metadataOptions); ok {
Expand All @@ -55,6 +57,7 @@ func WithMetadataNameIDFormat(format ...core.NameIDFormat) Option {
}
}

// WithACSServiceBinding provides an optional service binding.
func WithACSServiceBinding(b core.ServiceBinding) Option {
return func(o interface{}) {
if o, ok := o.(*metadataOptions); ok {
Expand All @@ -75,6 +78,7 @@ func WithAdditionalACSEndpoint(b core.ServiceBinding, location url.URL) Option {
}
}

// ServiceProvider defines a type for service providers
type ServiceProvider struct {
cfg *Config
}
Expand Down
12 changes: 8 additions & 4 deletions saml/sp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
)

func Test_NewServiceProvider(t *testing.T) {
t.Parallel()
r := require.New(t)
exampleURL := "http://test.me"

Expand Down Expand Up @@ -56,16 +57,17 @@ func Test_NewServiceProvider(t *testing.T) {
if c.err != "" {
r.Error(err)
r.ErrorContains(err, c.err)
} else {
r.NoError(err)
r.NotNil(got)
r.NotNil(got.Config())
return
}
r.NoError(err)
r.NotNil(got)
r.NotNil(got.Config())
})
}
}

func Test_ServiceProvider_FetchMetadata_ErrorCases(t *testing.T) {
t.Parallel()
r := require.New(t)

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
Expand Down Expand Up @@ -117,6 +119,7 @@ func Test_ServiceProvider_FetchMetadata_ErrorCases(t *testing.T) {
}

func Test_ServiceProvider_CreateMetadata(t *testing.T) {
t.Parallel()
r := require.New(t)

entityID := "http://test.me/entity"
Expand Down Expand Up @@ -185,6 +188,7 @@ func Test_ServiceProvider_CreateMetadata(t *testing.T) {
}

func Test_CreateMetadata_Options(t *testing.T) {
t.Parallel()
r := require.New(t)

fakeURL := "http://fake.test.url"
Expand Down

0 comments on commit af43639

Please sign in to comment.