Skip to content

Commit

Permalink
Merge pull request #26 from davidallendj/bss-token-fetch
Browse files Browse the repository at this point in the history
Changed how BSS fetches a new token
  • Loading branch information
davidallendj authored Mar 25, 2024
2 parents e832626 + 3888abb commit 929721d
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 70 deletions.
6 changes: 3 additions & 3 deletions cmd/boot-script-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ var (
jwksURL = ""
sqlDbOpts = ""
spireServiceURL = "https://spire-tokens.spire:54440"
oauth2AdminBaseURL = "http://127.0.0.1:4445"
oauth2PublicBaseURL = "http://127.0.0.1:4444"
oauth2AdminBaseURL = "http://127.0.0.1:3333"
oauth2PublicBaseURL = "http://127.0.0.1:3333"
)

func parseEnv(evar string, v interface{}) (ret error) {
Expand Down Expand Up @@ -453,7 +453,7 @@ func main() {
// try and fetch JWKS from issuer
if jwksURL != "" {
for i := uint64(0); i <= authRetryCount; i++ {
err := loadPublicKeyFromURL(jwksURL)
err := fetchPublicKey(jwksURL)
if err != nil {
log.Printf("failed to initialize auth token: %v", err)
time.Sleep(5 * time.Second)
Expand Down
113 changes: 72 additions & 41 deletions cmd/boot-script-service/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -23,6 +24,9 @@ import (
"net/url"
"time"

"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
)

Expand Down Expand Up @@ -51,6 +55,30 @@ func (nc nowClock) Now() time.Time {
return time.Now()
}

// fetchPublicKey fetches the JWKS (JSON Key Set) needed to verify JWTs with issuer.
func fetchPublicKey(url string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
set, err := jwk.Fetch(ctx, url)
if err != nil {
return fmt.Errorf("%v", err)
}
for it := set.Iterate(context.Background()); it.Next(context.Background()); {
pair := it.Pair()
key := pair.Value.(jwk.Key)

var rawkey interface{}
if err := key.Raw(&rawkey); err != nil {
continue
}

tokenAuth = jwtauth.New(jwa.RS256.String(), nil, rawkey)
return nil
}

return fmt.Errorf("failed to load public key: %v", err)
}

func (client *OAuthClient) CreateOAuthClient(registerUrl string) ([]byte, error) {
// hydra endpoint: POST /clients
data := []byte(`{
Expand Down Expand Up @@ -123,6 +151,11 @@ func (client *OAuthClient) PerformTokenGrant(remoteUrl string) (string, error) {
return "", fmt.Errorf("failed to unmarshal response body: %v", err)
}

accessToken := rjson["access_token"]
if accessToken == nil {
return "", fmt.Errorf("no access token found")
}

return rjson["access_token"].(string), nil
}

Expand All @@ -133,46 +166,6 @@ func QuoteArrayStrings(arr []string) []string {
return arr
}

// RequestClientCreds performs the requests to the OAuth2 server to obtain an
// access token for this client (BSS).
//
// 1. Register as OAuth2 client.
// 2. Authorize OAuth2 client that was created.
// 3. Obtain access token if OAuth2 client is authorized.
//
// Returns the OAuthClient struct containing the client ID, secret, etc. as well
// as the access token and an error if one occurred.
func (client *OAuthClient) RequestClientCreds() (accessToken string, err error) {
var (
url string
resp []byte
)

url = oauth2AdminBaseURL + "/admin/clients"
log.Printf("Attempting to register OAuth2 client")
debugf("Sending request to %s", url)
resp, err = client.CreateOAuthClient(url)
if err != nil {
err = fmt.Errorf("Failed to register OAuth2 client: %v", err)
debugf("Response: %v", string(resp))
return
}
log.Printf("Successfully registered OAuth2 client")
debugf("Client ID: %s", client.Id)

url = oauth2PublicBaseURL + "/oauth2/token"
log.Printf("Attempting to fetch token from authorization server")
debugf("Sending request to %s", url)
accessToken, err = client.PerformTokenGrant(url)
if err != nil {
err = fmt.Errorf("Failed to fetch token from authorization server: %v", err)
return
}
log.Printf("Successfully fetched token")

return
}

// PollClientCreds tries retryCount times every retryInterval seconds to request
// client credentials and an access token (JWT) from the OAuth2 server. If
// attempts are exhausted or an invalid retryInterval is passed, an error is
Expand All @@ -184,7 +177,7 @@ func (client *OAuthClient) PollClientCreds(retryCount, retryInterval uint64) err
}
for i := uint64(0); i < retryCount; i++ {
log.Printf("Attempting to obtain access token (attempt %d/%d)", i+1, retryCount)
token, err := client.RequestClientCreds()
token, err := client.FetchAccessToken(oauth2AdminBaseURL + "/token")
if err != nil {
log.Printf("Failed to obtain client credentials and token: %v", err)
time.Sleep(retryDuration)
Expand Down Expand Up @@ -259,3 +252,41 @@ func JWTIsValid(jwtStr string) (jwtValid bool, reason, err error) {

return
}

// FetchAccessToken fetches an access token for this client (BSS).
//
// Returns the access token string necessary to supply for authorization requests.
func (client *OAuthClient) FetchAccessToken(remoteUrl string) (string, error) {
// opaal endpoint: /token
headers := map[string][]string{
"no-browser": {},
}
req, err := http.NewRequest(http.MethodPost, remoteUrl, nil)
req.Header = headers
if err != nil {
return "", fmt.Errorf("failed to make request: %s", err)
}
res, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to do request: %v", err)
}
defer res.Body.Close()

b, err := io.ReadAll(res.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %v", err)
}

var rjson map[string]any
err = json.Unmarshal(b, &rjson)
if err != nil {
return "", fmt.Errorf("failed to unmarshal response body: %v", err)
}

accessToken := rjson["access_token"]
if accessToken == nil {
return "", fmt.Errorf("no access token found")
}

return rjson["access_token"].(string), nil
}
26 changes: 0 additions & 26 deletions cmd/boot-script-service/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
package main

import (
"context"
"fmt"
"net/http"
"time"
Expand All @@ -45,8 +44,6 @@ import (
"github.com/go-chi/chi/middleware"
"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
)

const (
Expand All @@ -62,29 +59,6 @@ var (
tokenAuth *jwtauth.JWTAuth
)

func loadPublicKeyFromURL(url string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
set, err := jwk.Fetch(ctx, url)
if err != nil {
return fmt.Errorf("%v", err)
}
for it := set.Iterate(context.Background()); it.Next(context.Background()); {
pair := it.Pair()
key := pair.Value.(jwk.Key)

var rawkey interface{}
if err := key.Raw(&rawkey); err != nil {
continue
}

tokenAuth = jwtauth.New(jwa.RS256.String(), nil, rawkey)
return nil
}

return fmt.Errorf("failed to load public key: %v", err)
}

func initHandlers() *chi.Mux {
router := chi.NewRouter()
router.Use(middleware.RequestID)
Expand Down

0 comments on commit 929721d

Please sign in to comment.