Skip to content

Commit

Permalink
tests (saml): minor code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Sep 9, 2023
1 parent 26c0e39 commit 154b738
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 56 deletions.
2 changes: 1 addition & 1 deletion saml/authn_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (sp *ServiceProvider) AuthnRequestPost(
b64Payload := base64.StdEncoding.EncodeToString(payload)

tmpl := template.Must(
template.New("post-binding").Parse(PostBindingTempl),
template.New("post-binding").Parse(postBindingTempl),
)

buf := bytes.Buffer{}
Expand Down
8 changes: 4 additions & 4 deletions saml/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ type configOptions struct {

func configOptionsDefault() configOptions {
return configOptions{
withValidUntil: DefaultValidUntil,
withValidUntil: defaultValidUntil,
}
}

Expand All @@ -245,7 +245,7 @@ func getConfigOptions(opt ...Option) configOptions {
opts.withGenerateAuthRequestID = DefaultGenerateAuthRequestID
}
if opts.withValidUntil == nil {
opts.withValidUntil = DefaultValidUntil
opts.withValidUntil = defaultValidUntil
}

return opts
Expand All @@ -264,8 +264,8 @@ func DefaultGenerateAuthRequestID() (string, error) {
return fmt.Sprintf("_%s", newID), nil
}

// DefaultValidUntil returns a timestamp with one year
// defaultValidUntil returns a timestamp with one year
// added to the time when this function is called.
func DefaultValidUntil() time.Time {
func defaultValidUntil() time.Time {
return time.Now().Add(time.Hour * 24 * 365)
}
27 changes: 15 additions & 12 deletions saml/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
)

func Test_NewConfig(t *testing.T) {
entityID := "http://test.me/entity"
acs := "http://test.me/sso/acs"
metadata := "http://test.me/sso/metadata"

t.Parallel()
const (
entityID = "http://test.me/entity"
acs = "http://test.me/sso/acs"
metadata = "http://test.me/sso/metadata"
)
cases := []struct {
name string
entityID string
Expand Down Expand Up @@ -62,21 +64,22 @@ func Test_NewConfig(t *testing.T) {

if c.expectedErr != "" {
r.ErrorContains(err, c.expectedErr)
} else {
r.NoError(err)
return
}
r.NoError(err)

r.Equal(got.EntityID, "http://test.me/entity")
r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs")
r.Equal(got.MetadataURL, "http://test.me/sso/metadata")
r.Equal(got.EntityID, "http://test.me/entity")
r.Equal(got.AssertionConsumerServiceURL, "http://test.me/sso/acs")
r.Equal(got.MetadataURL, "http://test.me/sso/metadata")

r.NotNil(got.GenerateAuthRequestID)
r.NotNil(got.ValidUntil)
}
r.NotNil(got.GenerateAuthRequestID)
r.NotNil(got.ValidUntil)
})
}
}

func Test_GenerateAuthRequestID(t *testing.T) {
t.Parallel()
r := require.New(t)

id, err := saml.DefaultGenerateAuthRequestID()
Expand Down
1 change: 0 additions & 1 deletion saml/internal/test/context.go

This file was deleted.

16 changes: 7 additions & 9 deletions saml/models/core/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/require"
)

var ResponseXMLSignature = `<?xml version="1.0" encoding="UTF-8"?>
var responseXMLSignature = `<?xml version="1.0" encoding="UTF-8"?>
<saml2p:Response xmlns:saml2p="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:xsd="http://www.w3.org/2001/XMLSchema" Destination="http://localhost:8000/saml/acs" ID="saml-response-id" InResponseTo="saml-request-id" IssueInstant="2023-03-31T06:55:44.494Z" Version="2.0">
<ds:Signature xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
<ds:SignedInfo>
Expand Down Expand Up @@ -39,6 +39,7 @@ var responseXMLContainer = `<?xml version="1.0" encoding="UTF-8"?>
</saml2p:Response>`

func Test_ParseResponse_ResponseContainer(t *testing.T) {
t.Parallel()
r := require.New(t)

res := responseXML(t, responseXMLContainer)
Expand All @@ -56,6 +57,7 @@ var responseXMLIssuer = `<?xml version="1.0" encoding="UTF-8"?>
</saml2p:Response>`

func Test_ParseResponse_Issuer(t *testing.T) {
t.Parallel()
r := require.New(t)

iss := responseXML(t, responseXMLIssuer).Issuer
Expand All @@ -70,21 +72,14 @@ var responseXMLStatus = `<?xml version="1.0" encoding="UTF-8"?>
</saml2p:Status>
</saml2p:Response>`

// func Test_ParseResponse_Status(t *testing.T) {
// r := require.New(t)

// status := responseXML(t, responseXMLStatus).Status

// r.Equal(status.StatusCode.Value, core.StatusCodeSuccess)
// }

var responseXMLAssertion = `<?xml version="1.0" encoding="UTF-8"?>
<saml2p:Response xmlns:saml2p="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:xsd="http://www.w3.org/2001/XMLSchema" Destination="http://localhost:8000/saml/acs" ID="saml-response-id" InResponseTo="saml-request-id" IssueInstant="2023-03-31T06:55:44.494Z" Version="2.0">
<saml2:Assertion xmlns:saml2="urn:oasis:names:tc:SAML:2.0:assertion" ID="assertion-id" IssueInstant="2023-03-31T06:55:44.494Z" Version="2.0">
</saml2:Assertion>
</saml2p:Response>`

func Test_ParseResponse_Assertion(t *testing.T) {
t.Parallel()
r := require.New(t)

assert := responseXML(t, responseXMLAssertion).Assertions[0]
Expand All @@ -103,6 +98,7 @@ var responseXMLAssertionIssuer = `<?xml version="1.0" encoding="UTF-8"?>
</saml2p:Response>`

func Test_ParseResponse_Assertion_Issuer(t *testing.T) {
t.Parallel()
r := require.New(t)

iss := responseXML(t, responseXMLAssertionIssuer).Assertions[0].Issuer
Expand All @@ -123,6 +119,7 @@ var responseXMLAssertionSubject = `<?xml version="1.0" encoding="UTF-8"?>
</saml2p:Response>`

func Test_ParseResponse_Assertion_Subject(t *testing.T) {
t.Parallel()
r := require.New(t)

sub := responseXML(t, responseXMLAssertionSubject).Assertions[0].Subject
Expand Down Expand Up @@ -187,6 +184,7 @@ var responseXMLAssertions = `<?xml version="1.0" encoding="UTF-8"?>
</saml2p:Response>`

func responseXML(t *testing.T, ssoRes string) core.Response {
t.Parallel()
t.Helper()

r := require.New(t)
Expand Down
2 changes: 0 additions & 2 deletions saml/models/metadata/entity_descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ type KeyDescriptor struct {
type KeyInfo struct {
dsig.KeyInfo
KeyName string
// XMLName xml.Name `xml:"http://www.w3.org/2000/09/xmldsig# KeyInfo"`
// X509Data X509Data `xml:"X509Data"`
}

// EncyrptionMethod describes the encryption algorithm applied to the cipher data.
Expand Down
6 changes: 6 additions & 0 deletions saml/models/metadata/idp_sso_descriptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ var exampleIDPSSODescriptor = `<?xml version="1.0" encoding="UTF-8"?>
</EntityDescriptor>`

func Test_IDPSSODescriptor(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorIDPSSO{}
Expand Down Expand Up @@ -74,6 +75,7 @@ var exampleIDPSSOKeyDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
</EntityDescriptor>`

func Test_IDPSSODescriptor_KeyDescriptor(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorIDPSSO{}
Expand All @@ -98,6 +100,7 @@ var exampleIDPSSODescriptorArtifactResolutionService = `<?xml version="1.0" enco
</EntityDescriptor>`

func Test_IDPSSODescriptor_ArtifactResolutionService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorIDPSSO{}
Expand Down Expand Up @@ -126,6 +129,7 @@ var exampleIDPSSODescriptorSLO = `<?xml version="1.0" encoding="UTF-8"?>
</EntityDescriptor>`

func Test_IDPSSODescriptor_SLO(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorIDPSSO{}
Expand Down Expand Up @@ -155,6 +159,7 @@ var exampleIDPSSODescriptorSSO = `<?xml version="1.0" encoding="UTF-8"?>
</EntityDescriptor>`

func Test_IDPSSODescriptor_SSO(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorIDPSSO{}
Expand Down Expand Up @@ -190,6 +195,7 @@ var exampleIDPSSODescriptorAttributes = `<?xml version="1.0" encoding="UTF-8"?>
</EntityDescriptor>`

func Test_IDPSSODescriptor_Attributes(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorIDPSSO{}
Expand Down
25 changes: 6 additions & 19 deletions saml/models/metadata/sp_sso_descriptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ var exampleSPSSODescriptor = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -89,6 +90,7 @@ var exampleSLOService = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_SLOService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -129,6 +131,7 @@ var exampleNameIDService = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_ManageNameIDService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -160,6 +163,7 @@ var exampleNameIDFormats = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_NameIDFormats(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -193,6 +197,7 @@ var exampleACS = `<EntityDescriptor
</EntityDescriptor>`

func Test_SPSSODescriptor_ACS(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -250,6 +255,7 @@ var exampleAttributeConsumingService = `<EntityDescriptor
// <saml:AttributeValue type="xs:string">By-Tor</saml:AttributeValue>

func Test_SPSSODescriptor_AttributeConsumingService(t *testing.T) {
t.Parallel()
r := require.New(t)

ed := &metadata.EntityDescriptorSPSSO{}
Expand Down Expand Up @@ -332,22 +338,3 @@ x5Ql0ejivIJAYcMGUyA+/YwJg2FGoA==
</KeyDescriptor>
</SPSSODescriptor>
</EntityDescriptor>`

// func Test_SPSSODescriptor_KeyDescritpor(t *testing.T) {
// r := require.New(t)

// ed := &metadata.EntityDescriptor{}

// err := xml.Unmarshal([]byte(exampleKeyDescriptor), ed)
// r.NoError(err)

// keyDescriptor := ed.SPSSODescriptor[0].KeyDescriptor

// r.Len(keyDescriptor, 2)

// r.Equal(keyDescriptor[0].Use, metadata.KeyTypeSigning)
// r.NotEmpty(keyDescriptor[0].KeyInfo.X509Data, "")

// r.Equal(keyDescriptor[1].Use, metadata.KeyTypeEncryption)
// r.NotEmpty(keyDescriptor[1].KeyInfo.X509Data, "")
// }
2 changes: 1 addition & 1 deletion saml/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (sp *ServiceProvider) internalParser(
for _, xcert := range kd.KeyInfo.X509Data.X509Certificates {
parsed, err := parseX509Certificate(xcert.Data)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s: unable to parse cert: %w", op, err)
}
certStore.Roots = append(certStore.Roots, parsed) // append works just fine with a nil slice
}
Expand Down
5 changes: 3 additions & 2 deletions saml/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func TestServiceProvider_ParseResponse(t *testing.T) {
}

func TestServiceProvider_ParseResponseCustomACS(t *testing.T) {
t.Parallel()
r := require.New(t)

fakeTime, err := time.Parse("2006-01-02", "2015-07-15")
Expand Down Expand Up @@ -151,9 +152,9 @@ func TestServiceProvider_ParseResponseCustomACS(t *testing.T) {
)
if c.err == "" {
require.NoError(t, err)
} else {
require.ErrorContains(t, err, c.err)
return
}
require.ErrorContains(t, err, c.err)
})
}

Expand Down
6 changes: 5 additions & 1 deletion saml/sp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
)

//go:embed authn_request.gohtml
var PostBindingTempl string
var postBindingTempl string

type metadataOptions struct {
wantAssertionsSigned bool
Expand All @@ -37,6 +37,8 @@ func getMetadataOptions(opt ...Option) metadataOptions {
return opts
}

// InsecureWantAssertionsUnsigned provides a way to optionally request that you
// want insecure/unsigned assertions.
func InsecureWantAssertionsUnsigned() Option {
return func(o interface{}) {
if o, ok := o.(*metadataOptions); ok {
Expand All @@ -55,6 +57,7 @@ func WithMetadataNameIDFormat(format ...core.NameIDFormat) Option {
}
}

// WithACSServiceBinding provides an optional service binding.
func WithACSServiceBinding(b core.ServiceBinding) Option {
return func(o interface{}) {
if o, ok := o.(*metadataOptions); ok {
Expand All @@ -75,6 +78,7 @@ func WithAdditionalACSEndpoint(b core.ServiceBinding, location url.URL) Option {
}
}

// ServiceProvider defines a type for service providers
type ServiceProvider struct {
cfg *Config
}
Expand Down
12 changes: 8 additions & 4 deletions saml/sp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
)

func Test_NewServiceProvider(t *testing.T) {
t.Parallel()
r := require.New(t)
exampleURL := "http://test.me"

Expand Down Expand Up @@ -56,16 +57,17 @@ func Test_NewServiceProvider(t *testing.T) {
if c.err != "" {
r.Error(err)
r.ErrorContains(err, c.err)
} else {
r.NoError(err)
r.NotNil(got)
r.NotNil(got.Config())
return
}
r.NoError(err)
r.NotNil(got)
r.NotNil(got.Config())
})
}
}

func Test_ServiceProvider_FetchMetadata_ErrorCases(t *testing.T) {
t.Parallel()
r := require.New(t)

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
Expand Down Expand Up @@ -117,6 +119,7 @@ func Test_ServiceProvider_FetchMetadata_ErrorCases(t *testing.T) {
}

func Test_ServiceProvider_CreateMetadata(t *testing.T) {
t.Parallel()
r := require.New(t)

entityID := "http://test.me/entity"
Expand Down Expand Up @@ -185,6 +188,7 @@ func Test_ServiceProvider_CreateMetadata(t *testing.T) {
}

func Test_CreateMetadata_Options(t *testing.T) {
t.Parallel()
r := require.New(t)

fakeURL := "http://fake.test.url"
Expand Down

0 comments on commit 154b738

Please sign in to comment.