Skip to content

Commit

Permalink
bearer: inject Token instead of User (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
unknwon authored Mar 13, 2022
1 parent 4c221ed commit 78032fe
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 17 deletions.
3 changes: 0 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ import (
"crypto/subtle"
)

// User is the authenticated username that was extracted from the request.
type User string

// SecureCompare performs a constant time compare of two strings to prevent
// timing attacks.
func SecureCompare(given, actual string) bool {
Expand Down
3 changes: 3 additions & 0 deletions basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (

const basicPrefix = "Basic "

// User is the authenticated username that was extracted from the request.
type User string

// Basic returns a middleware handler that injects auth.User into the request
// context upon successful basic authentication. The handler responds
// http.StatusUnauthorized when authentication fails.
Expand Down
24 changes: 20 additions & 4 deletions basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/flamego/flamego"
)
Expand All @@ -23,35 +24,41 @@ func TestBasic(t *testing.T) {
username string
password string
wantCode int
wantBody string
}{
{
name: "good",
username: "foo",
password: "bar",
wantCode: http.StatusOK,
wantBody: "foo",
},
{
name: "bad",
username: "bar",
password: "foo",
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := flamego.NewWithLogger(&bytes.Buffer{})
f.Use(Basic("foo", "bar"))
f.Get("/", func() {})
f.Get("/", func(user User) string {
return string(user)
})

resp := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
require.NoError(t, err)

auth := strings.Join([]string{test.username, test.password}, ":")
req.Header.Set("Authorization", basicPrefix+base64.StdEncoding.EncodeToString([]byte(auth)))
f.ServeHTTP(resp, req)

assert.Equal(t, test.wantCode, resp.Code)
assert.Equal(t, test.wantBody, resp.Body.String())
})
}
}
Expand All @@ -61,31 +68,37 @@ func TestBasicFunc(t *testing.T) {
name string
header string
wantCode int
wantBody string
}{
{
name: "primary password",
header: basicPrefix + "Zm9vOmJhcg==", // foo:bar
wantCode: http.StatusOK,
wantBody: "foo",
},
{
name: "secondary password",
header: basicPrefix + "Zm9vOmJheg==", // foo:baz
wantCode: http.StatusOK,
wantBody: "foo",
},
{
name: "wrong password",
header: basicPrefix + "Zm9vOm5vcGU=", // foo:nope
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
{
name: "bad prefix",
header: "Zm9vOmJheg==", // foo:baz
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
{
name: "bad encoding",
header: basicPrefix + "Zm9vOm5",
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
}
for _, test := range tests {
Expand All @@ -94,16 +107,19 @@ func TestBasicFunc(t *testing.T) {
f.Use(BasicFunc(func(username, password string) bool {
return username == "foo" && (password == "bar" || password == "baz")
}))
f.Get("/", func() {})
f.Get("/", func(user User) string {
return string(user)
})

resp := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
require.NoError(t, err)

req.Header.Set("Authorization", test.header)
f.ServeHTTP(resp, req)

assert.Equal(t, test.wantCode, resp.Code)
assert.Equal(t, test.wantBody, resp.Body.String())
})
}
}
11 changes: 8 additions & 3 deletions bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (

var bearerPrefix = "Bearer "

// Token is the authenticated token that was extracted from the request.
type Token string

// Bearer returns a middleware handler that injects auth.User (empty string)
// into the request context upon successful bearer authentication. The handler
// responds http.StatusUnauthorized when authentication fails.
Expand All @@ -22,7 +25,7 @@ func Bearer(token string) flamego.Handler {
bearerUnauthorized(c.ResponseWriter())
return
}
c.Map(User(""))
c.Map(Token(token))
})
}

Expand All @@ -37,11 +40,13 @@ func BearerFunc(fn func(token string) bool) flamego.Handler {
bearerUnauthorized(c.ResponseWriter())
return
}
if !fn(auth[n:]) {

token := auth[n:]
if !fn(token) {
bearerUnauthorized(c.ResponseWriter())
return
}
c.Map(User(""))
c.Map(Token(token))
})
}

Expand Down
29 changes: 22 additions & 7 deletions bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,46 @@ import (

"github.com/flamego/flamego"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestBearer(t *testing.T) {
tests := []struct {
name string
token string
wantCode int
wantBody string
}{
{
name: "good",
token: "foo",
wantCode: http.StatusOK,
wantBody: "foo",
},
{
name: "bad",
token: "bar",
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := flamego.NewWithLogger(&bytes.Buffer{})
f.Use(Bearer("foo"))
f.Get("/", func() {})
f.Get("/", func(token Token) string {
return string(token)
})

resp := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
require.NoError(t, err)

req.Header.Set("Authorization", bearerPrefix+test.token)
f.ServeHTTP(resp, req)

assert.Equal(t, test.wantCode, resp.Code)
assert.Equal(t, test.wantBody, resp.Body.String())
})
}
}
Expand All @@ -54,26 +61,31 @@ func TestBearerFunc(t *testing.T) {
name string
header string
wantCode int
wantBody string
}{
{
name: "primary password",
name: "primary token",
header: bearerPrefix + "foo",
wantCode: http.StatusOK,
wantBody: "foo",
},
{
name: "secondary password",
name: "secondary token",
header: bearerPrefix + "bar",
wantCode: http.StatusOK,
wantBody: "bar",
},
{
name: "wrong password",
name: "wrong token",
header: bearerPrefix + "nope",
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
{
name: "bad prefix",
header: "foo",
wantCode: http.StatusUnauthorized,
wantBody: "Unauthorized\n",
},
}
for _, test := range tests {
Expand All @@ -82,16 +94,19 @@ func TestBearerFunc(t *testing.T) {
f.Use(BearerFunc(func(token string) bool {
return token == "foo" || token == "bar"
}))
f.Get("/", func() {})
f.Get("/", func(token Token) string {
return string(token)
})

resp := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
require.NoError(t, err)

req.Header.Set("Authorization", test.header)
f.ServeHTTP(resp, req)

assert.Equal(t, test.wantCode, resp.Code)
assert.Equal(t, test.wantBody, resp.Body.String())
})
}
}

0 comments on commit 78032fe

Please sign in to comment.