Skip to content

Commit

Permalink
feat: Support for tenantID in azuread provider
Browse files Browse the repository at this point in the history
Signed-off-by: Pedro Parra Ortega <pedro.parraortega@enreach.com>

fix: revert non desired changes

Signed-off-by: Pedro Parra Ortega <pedro.parraortega@enreach.com>
  • Loading branch information
ppodevlabs committed Dec 16, 2024
1 parent 0c63ed9 commit ca2d7c1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 22 deletions.
2 changes: 1 addition & 1 deletion gothic/gothic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/go-chi/chi/v5"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"

"github.com/go-chi/chi/v5"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/markbates/goth"
Expand Down
65 changes: 45 additions & 20 deletions providers/azuread/azuread.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,50 @@ import (
)

const (
authURL string = "https://login.microsoftonline.com/common/oauth2/authorize"
tokenURL string = "https://login.microsoftonline.com/common/oauth2/token"
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/authorize"
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/token"
endpointProfile string = "https://graph.windows.net/me?api-version=1.6"
graphAPIResource string = "https://graph.windows.net/"
commonTenant string = "common"
)

// New creates a new AzureAD provider, and sets up important connection details.
// You should always call `AzureAD.New` to get a new Provider. Never try to create
// one manually.
func New(clientKey, secret, callbackURL string, resources []string, scopes ...string) *Provider {
func New(clientKey, secret, callbackURL string, opts ProviderOpts) *Provider {
p := &Provider{
ClientKey: clientKey,
Secret: secret,
CallbackURL: callbackURL,
providerName: "azuread",
}

p.resources = make([]string, 0, 1+len(resources))
p.resources = make([]string, 0, 1+len(opts.Resources))
p.resources = append(p.resources, graphAPIResource)
p.resources = append(p.resources, resources...)
p.resources = append(p.resources, opts.Resources...)

p.config = newConfig(p, scopes)
p.config = newConfig(p, opts)
return p
}

// Provider is the implementation of `goth.Provider` for accessing AzureAD.
type Provider struct {
ClientKey string
Secret string
CallbackURL string
HTTPClient *http.Client
config *oauth2.Config
providerName string
resources []string
}
type (
Provider struct {
ClientKey string
Secret string
CallbackURL string
HTTPClient *http.Client
config *oauth2.Config
providerName string
resources []string
}

ProviderOpts struct {
Resources []string
Scopes []string
TenantID string
}
)

// Name is the name used to retrieve this provider later.
func (p *Provider) Name() string {
Expand Down Expand Up @@ -132,20 +141,20 @@ func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
return newToken, err
}

func newConfig(provider *Provider, scopes []string) *oauth2.Config {
func newConfig(provider *Provider, opts ProviderOpts) *oauth2.Config {
c := &oauth2.Config{
ClientID: provider.ClientKey,
ClientSecret: provider.Secret,
RedirectURL: provider.CallbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
AuthURL: authURL(opts.TenantID),
TokenURL: tokenURL(opts.TenantID),
},
Scopes: []string{},
}

if len(scopes) > 0 {
for _, scope := range scopes {
if len(opts.Scopes) > 0 {
for _, scope := range opts.Scopes {
c.Scopes = append(c.Scopes, scope)
}
} else {
Expand Down Expand Up @@ -185,3 +194,19 @@ func userFromReader(r io.Reader, user *goth.User) error {
func authorizationHeader(session *Session) (string, string) {
return "Authorization", fmt.Sprintf("Bearer %s", session.AccessToken)
}

func authURL(tenantID string) string {
if tenantID != "" {
return fmt.Sprintf(authURLTemplate, tenantID)
} else {
return fmt.Sprintf(authURLTemplate, commonTenant)
}
}

func tokenURL(tenantID string) string {
if tenantID != "" {
return fmt.Sprintf(tokenURLTemplate, tenantID)
} else {
return fmt.Sprintf(tokenURLTemplate, commonTenant)
}
}
2 changes: 1 addition & 1 deletion providers/azuread/azuread_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ func Test_SessionFromJSON(t *testing.T) {
}

func azureadProvider() *azuread.Provider {
return azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "/foo", nil)
return azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "/foo", azuread.ProviderOpts{})
}

0 comments on commit ca2d7c1

Please sign in to comment.