diff --git a/device/device_flow.go b/device/device_flow.go index e6a6ee9..0447980 100644 --- a/device/device_flow.go +++ b/device/device_flow.go @@ -54,12 +54,31 @@ type CodeResponse struct { Interval int } +// AuthRequestEditorFn defines the function signature for setting additional form values. +type AuthRequestEditorFn func(*url.Values) + +// WithAudience sets the audience parameter in the request. +func WithAudience(audience string) AuthRequestEditorFn { + return func(values *url.Values) { + if audience != "" { + values.Add("audience", audience) + } + } +} + // RequestCode initiates the authorization flow by requesting a code from uri. -func RequestCode(c httpClient, uri string, clientID string, scopes []string) (*CodeResponse, error) { - resp, err := api.PostForm(c, uri, url.Values{ +func RequestCode(c httpClient, uri string, clientID string, scopes []string, + optionalRequestParams ...AuthRequestEditorFn) (*CodeResponse, error) { + values := url.Values{ "client_id": {clientID}, "scope": {strings.Join(scopes, " ")}, - }) + } + + for _, fn := range optionalRequestParams { + fn(&values) + } + + resp, err := api.PostForm(c, uri, values) if err != nil { return nil, err } diff --git a/device/device_flow_test.go b/device/device_flow_test.go index b43d18c..08fbd0f 100644 --- a/device/device_flow_test.go +++ b/device/device_flow_test.go @@ -51,6 +51,7 @@ func TestRequestCode(t *testing.T) { url string clientID string scopes []string + audience string } tests := []struct { name string @@ -126,6 +127,42 @@ func TestRequestCode(t *testing.T) { }, }, }, + { + name: "with audience", + args: args{ + http: apiClient{ + stubs: []apiStub{ + { + body: "verification_uri=http://verify.me&interval=5&expires_in=99&device_code=DEVIC&user_code=123-abc&verification_uri_complete=http://verify.me/?code=123-abc", + status: 200, + contentType: "application/x-www-form-urlencoded; charset=utf-8", + }, + }, + }, + url: "https://github.com/oauth", + clientID: "CLIENT-ID", + scopes: []string{"repo", "gist"}, + audience: "https://api.github.com", + }, + want: &CodeResponse{ + DeviceCode: "DEVIC", + UserCode: "123-abc", + VerificationURI: "http://verify.me", + VerificationURIComplete: "http://verify.me/?code=123-abc", + ExpiresIn: 99, + Interval: 5, + }, + posts: []postArgs{ + { + url: "https://github.com/oauth", + params: url.Values{ + "client_id": {"CLIENT-ID"}, + "scope": {"repo gist"}, + "audience": {"https://api.github.com"}, + }, + }, + }, + }, { name: "unsupported", args: args{ @@ -237,7 +274,8 @@ func TestRequestCode(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := RequestCode(&tt.args.http, tt.args.url, tt.args.clientID, tt.args.scopes) + got, err := RequestCode(&tt.args.http, tt.args.url, + tt.args.clientID, tt.args.scopes, WithAudience(tt.args.audience)) if (err != nil) != (tt.wantErr != "") { t.Errorf("RequestCode() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/oauth.go b/oauth.go index 7e38c71..5b98c38 100644 --- a/oauth.go +++ b/oauth.go @@ -68,6 +68,8 @@ type Flow struct { Host *Host // OAuth scopes to request from the user. Scopes []string + // OAuth audience to request from the user. + Audience string // OAuth application ID. ClientID string // OAuth application secret. Only applicable in web application flow. diff --git a/oauth_device.go b/oauth_device.go index aef280f..f993eaa 100644 --- a/oauth_device.go +++ b/oauth_device.go @@ -39,7 +39,8 @@ func (oa *Flow) DeviceFlow() (*api.AccessToken, error) { host = parsedHost } - code, err := device.RequestCode(httpClient, host.DeviceCodeURL, oa.ClientID, oa.Scopes) + code, err := device.RequestCode(httpClient, host.DeviceCodeURL, + oa.ClientID, oa.Scopes, device.WithAudience(oa.Audience)) if err != nil { return nil, err } diff --git a/oauth_webapp.go b/oauth_webapp.go index 9c7af15..360c53b 100644 --- a/oauth_webapp.go +++ b/oauth_webapp.go @@ -32,6 +32,7 @@ func (oa *Flow) WebAppFlow() (*api.AccessToken, error) { ClientID: oa.ClientID, RedirectURI: oa.CallbackURI, Scopes: oa.Scopes, + Audience: oa.Audience, AllowSignup: true, } browserURL, err := flow.BrowserURL(host.AuthorizeURL, params) diff --git a/webapp/webapp_flow.go b/webapp/webapp_flow.go index 49c6551..a6c96ae 100644 --- a/webapp/webapp_flow.go +++ b/webapp/webapp_flow.go @@ -47,6 +47,7 @@ type BrowserParams struct { ClientID string RedirectURI string Scopes []string + Audience string LoginHandle string AllowSignup bool } @@ -68,6 +69,10 @@ func (flow *Flow) BrowserURL(baseURL string, params BrowserParams) (string, erro q.Set("redirect_uri", ru.String()) q.Set("scope", strings.Join(params.Scopes, " ")) q.Set("state", flow.state) + + if params.Audience != "" { + q.Set("audience", params.Audience) + } if params.LoginHandle != "" { q.Set("login", params.LoginHandle) } diff --git a/webapp/webapp_flow_test.go b/webapp/webapp_flow_test.go index f4e4355..110ff0c 100644 --- a/webapp/webapp_flow_test.go +++ b/webapp/webapp_flow_test.go @@ -51,6 +51,24 @@ func TestFlow_BrowserURL(t *testing.T) { want: "https://github.com/authorize?client_id=CLIENT-ID&redirect_uri=http%3A%2F%2F127.0.0.1%3A12345%2Fhello&scope=repo+read%3Aorg&state=xy%2Fz", wantErr: false, }, + { + name: "happy path with audience", + fields: fields{ + server: server, + state: "xy/z", + }, + args: args{ + baseURL: "https://github.com/authorize", + params: BrowserParams{ + ClientID: "CLIENT-ID", + RedirectURI: "http://127.0.0.1/hello", + Scopes: []string{"repo", "read:org"}, + AllowSignup: true, + Audience: "https://api.github.com", + }, + }, + want: "https://github.com/authorize?audience=https%3A%2F%2Fapi.github.com&client_id=CLIENT-ID&redirect_uri=http%3A%2F%2F127.0.0.1%3A12345%2Fhello&scope=repo+read%3Aorg&state=xy%2Fz", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {