diff --git a/saml/is_nil.go b/saml/is_nil.go new file mode 100644 index 0000000..c54654e --- /dev/null +++ b/saml/is_nil.go @@ -0,0 +1,15 @@ +package saml + +import "reflect" + +// isNil reports if a is nil +func isNil(a any) bool { + if a == nil { + return true + } + switch reflect.TypeOf(a).Kind() { + case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func: + return reflect.ValueOf(a).IsNil() + } + return false +} diff --git a/saml/is_nil_test.go b/saml/is_nil_test.go new file mode 100644 index 0000000..ff9fd40 --- /dev/null +++ b/saml/is_nil_test.go @@ -0,0 +1,56 @@ +package saml + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_isNil(t *testing.T) { + t.Parallel() + + var testErrNilPtr *testError + var testMapNilPtr map[string]struct{} + var testArrayNilPtr *[1]string + var testChanNilPtr *chan string + var testSliceNilPtr *[]string + var testFuncNil func() + + var testChanString chan string + + tc := []struct { + i any + want bool + }{ + {i: &testError{}, want: false}, + {i: testError{}, want: false}, + {i: &map[string]struct{}{}, want: false}, + {i: map[string]struct{}{}, want: false}, + {i: [1]string{}, want: false}, + {i: &[1]string{}, want: false}, + {i: &testChanString, want: false}, + {i: "string", want: false}, + {i: []string{}, want: false}, + {i: func() {}, want: false}, + {i: nil, want: true}, + {i: testErrNilPtr, want: true}, + {i: testMapNilPtr, want: true}, + {i: testArrayNilPtr, want: true}, + {i: testChanNilPtr, want: true}, + {i: testChanString, want: true}, + {i: testSliceNilPtr, want: true}, + {i: testFuncNil, want: true}, + } + + for i, tc := range tc { + t.Run(fmt.Sprintf("test #%d", i+1), func(t *testing.T) { + assert := assert.New(t) + assert.Equal(tc.want, isNil(tc.i)) + }) + } +} + +type testError struct{} + +func (*testError) Error() string { return "error" } diff --git a/saml/response.go b/saml/response.go index 41c7ede..1af9a47 100644 --- a/saml/response.go +++ b/saml/response.go @@ -156,6 +156,10 @@ func (sp *ServiceProvider) internalParser( clock clockwork.Clock, ) (*saml2.SAMLServiceProvider, error) { const op = "saml.(ServiceProvider).internalParser" + switch { + case isNil(clock): + return nil, fmt.Errorf("%s: missing clock: %w", op, ErrInvalidParameter) + } idpMetadata, err := sp.IDPMetadata() if err != nil { return nil, fmt.Errorf("%s: %w", op, err)