Skip to content

Commit

Permalink
Merge pull request #2 from synackd/bss-auth-impl-token
Browse files Browse the repository at this point in the history
Add JWT verification and use JWT with SMD
  • Loading branch information
davidallendj authored Mar 8, 2024
2 parents c1ec85f + 9cd350a commit 8f40cfa
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 61 deletions.
15 changes: 10 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,16 @@ ENV BSS_HSM_RETRIEVAL_DELAY=10
# is disabled.
# BSS_JWKS_URL=""
#
# Base URL of the OAUTH2 server to use for client authorizations when
# JWT authentication is enabled. This is used to authorize BSS to be
# able to communicate with protected SMD endpoints when it is queried
# for a boot script.
# BSS_OAUTH2_BASE_URL=http://127.0.0.1:4444
# Base URL of the Oauth2 server admin endpoints to use for client authorizations
# when JWT authentication is enabled. This is used to authorize BSS via a client
# credentials grant to be able to communicate with protected SMD endpoints when
# it is queried for a boot script.
# BSS_OAUTH2_ADMIN_BASE_URL=http://127.0.0.1:4445
#
# Base URL of the OAuth2 server public endpoints to use for non-admin requests
# like a client (e.g. BSS) requesting an access token after it has been
# authorized.
# BSS_OAUTH2_USER_BASE_URL=http://127.0.0.1:4444

# Etcd variables with default values:
#
Expand Down
77 changes: 38 additions & 39 deletions cmd/boot-script-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ import (
"github.com/OpenCHAMI/bss/internal/postgres"
)

const kvDefaultRetryCount uint64 = 10
const kvDefaultRetryWait uint64 = 5
const sqlDefaultRetryCount uint64 = 10
const sqlDefaultRetryWait uint64 = 5
const authDefaultRetryCount uint64 = 10
const (
kvDefaultRetryCount uint64 = 10
kvDefaultRetryWait uint64 = 5
sqlDefaultRetryCount uint64 = 10
sqlDefaultRetryWait uint64 = 5
authDefaultRetryCount uint64 = 10
authDefaultRetryWait uint64 = 5
)

var (
httpListen = ":27778"
Expand All @@ -84,22 +87,23 @@ var (
// TODO: Set the default to a well known link local address when we have it.
// This will also mean we change the virtual service into an Ingress with
// this well known IP.
advertiseAddress = "" // i.e. http://{IP to reach this service}
insecure = false
debugFlag = false
kvstore hmetcd.Kvi
retryDelay = uint(30)
hsmRetrievalDelay = uint(10)
sqlRetryCount = sqlDefaultRetryCount
sqlRetryWait = sqlDefaultRetryWait
notifier *ScnNotifier
useSQL = false // Use ETCD by default
authRetryCount = authDefaultRetryCount
jwksURL = ""
accessToken = ""
sqlDbOpts = ""
spireServiceURL = "https://spire-tokens.spire:54440"
oauth2BaseURL = "http://127.0.0.1:4444"
advertiseAddress = "" // i.e. http://{IP to reach this service}
insecure = false
debugFlag = false
kvstore hmetcd.Kvi
retryDelay = uint(30)
hsmRetrievalDelay = uint(10)
sqlRetryCount = sqlDefaultRetryCount
sqlRetryWait = sqlDefaultRetryWait
notifier *ScnNotifier
useSQL = false // Use ETCD by default
authRetryCount = authDefaultRetryCount
authRetryWait = authDefaultRetryWait
jwksURL = ""
sqlDbOpts = ""
spireServiceURL = "https://spire-tokens.spire:54440"
oauth2AdminBaseURL = "http://127.0.0.1:4445"
oauth2PublicBaseURL = "http://127.0.0.1:4444"
)

func parseEnv(evar string, v interface{}) (ret error) {
Expand Down Expand Up @@ -304,13 +308,21 @@ func parseEnvVars() error {
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_AUTH_RETRY_COUNT: %q", parseErr))
}
parseErr = parseEnv("BSS_AUTH_RETRY_WAIT", &authRetryWait)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_AUTH_RETRY_WAIT: %q", parseErr))
}
parseErr = parseEnv("BSS_JWKS_URL", &jwksURL)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_JWKS_URL: %q", parseErr))
}
parseErr = parseEnv("BSS_OAUTH2_BASE_URL", &oauth2BaseURL)
parseErr = parseEnv("BSS_OAUTH2_ADMIN_BASE_URL", &oauth2AdminBaseURL)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_OAUTH2_BASE_URL: %q", parseErr))
errList = append(errList, fmt.Errorf("BSS_OAUTH2_ADMIN_BASE_URL: %q", parseErr))
}
parseErr = parseEnv("BSS_OAUTH2_PUBLIC_BASE_URL", &oauth2PublicBaseURL)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_OAUTH2_PUBLIC_BASE_URL: %q", parseErr))
}

//
Expand Down Expand Up @@ -407,14 +419,16 @@ func parseCmdLine() {
flag.StringVar(&sqlUser, "postgres-username", sqlUser, "(BSS_DBUSER) Postgres username")
flag.StringVar(&sqlPass, "postgres-password", sqlPass, "(BSS_DBPASS) Postgres password")
flag.StringVar(&jwksURL, "jwks-url", jwksURL, "(BSS_JWKS_URL) Set the JWKS URL to fetch the public key for authorization (enables authentication)")
flag.StringVar(&oauth2BaseURL, "oauth2-base-url", oauth2BaseURL, "(BSS_OAUTH2_BASE_URL) Base URL of the OAUTH2 server for client authorizations")
flag.StringVar(&oauth2AdminBaseURL, "oauth2-admin-base-url", oauth2AdminBaseURL, "(BSS_OAUTH2_ADMIN_BASE_URL) Base URL of the OAUTH2 server admin endpoints for client authorizations")
flag.StringVar(&oauth2PublicBaseURL, "oauth2-public-base-url", oauth2PublicBaseURL, "(BSS_OAUTH2_PUBLIC_BASE_URL) Base URL of the OAUTH2 server public endpoints (e.g. for token grants)")
flag.BoolVar(&insecure, "insecure", insecure, "(BSS_INSECURE) Don't enforce https certificate security")
flag.BoolVar(&debugFlag, "debug", debugFlag, "(BSS_DEBUG) Enable debug output")
flag.BoolVar(&useSQL, "postgres", useSQL, "(BSS_USESQL) Use Postgres instead of ETCD")
flag.UintVar(&retryDelay, "retry-delay", retryDelay, "(BSS_RETRY_DELAY) Retry delay in seconds")
flag.UintVar(&hsmRetrievalDelay, "hsm-retrieval-delay", hsmRetrievalDelay, "(BSS_HSM_RETRIEVAL_DELAY) SM Retrieval delay in seconds")
flag.UintVar(&sqlPort, "postgres-port", sqlPort, "(BSS_DBPORT) Postgres port")
flag.Uint64Var(&authRetryCount, "auth-retry-count", authRetryCount, "(BSS_AUTH_RETRY_COUNT) Retry fetching JWKS public key set")
flag.Uint64Var(&authRetryWait, "auth-retry-wait", authRetryWait, "(BSS_AUTH_RETRY_WAIT) Interval in seconds between authentication request attempts")
flag.Uint64Var(&sqlRetryCount, "postgres-retry-count", sqlRetryCount, "(BSS_SQL_RETRY_COUNT) Amount of times to retry connecting to Postgres")
flag.Uint64Var(&sqlRetryWait, "postgres-retry-wait", sqlRetryCount, "(BSS_SQL_RETRY_WAIT) Interval in seconds between connection attempts to Postgres")
flag.Parse()
Expand Down Expand Up @@ -449,21 +463,6 @@ func main() {
break
}
}
// register oauth2 client and receive token
var client OAuthClient
_, err = client.CreateOAuthClient(oauth2BaseURL+"/admin/clients")
if err != nil {
log.Fatalf("failed to register OAuth client: %v", err)
}
_, err = client.AuthorizeOAuthClient(oauth2BaseURL+"/oauth2/auth")
if err != nil {
log.Fatalf("failed to authorize OAuth client: %v", err)
}
accessToken, err = client.PerformTokenGrant(oauth2BaseURL+"/oauth2/token")
if err != nil {
log.Fatalf("failed to fetch token from authorization server: %v", err)
}
log.Printf("Access Token: %v\n", accessToken)

var svcOpts string
if insecure {
Expand Down
179 changes: 175 additions & 4 deletions cmd/boot-script-service/oauth.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
// NOTE: Triad License goes here
// Copyright © 2024 Triad National Security, LLC. All rights reserved.
//
// This program was produced under U.S. Government contract 89233218CNA000001
// for Los Alamos National Laboratory (LANL), which is operated by Triad
// National Security, LLC for the U.S. Department of Energy/National Nuclear
// Security Administration. All rights in the program are reserved by Triad
// National Security, LLC, and the U.S. Department of Energy/National Nuclear
// Security Administration. The Government is granted for itself and others
// acting on its behalf a nonexclusive, paid-up, irrevocable worldwide license
// in this material to reproduce, prepare derivative works, distribute copies to
// the public, perform publicly and display publicly, and to permit others to do
// so.

package main

import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"time"

"github.com/lestrrat-go/jwx/jwt"
)

var accessToken = ""

type OAuthClient struct {
http.Client
Id string
Expand All @@ -18,6 +36,21 @@ type OAuthClient struct {
RedirectUris []string
}

// This is to implement jwt.Clock and provide the Now() function. An empty
// instance of this struct will be passed to the jwt.WithClock() function so it
// knows how to verify the timestamps.
type nowClock struct {
jwt.Clock
}

// This function returns whatever "now" is for jwt.Clock. We simply return
// time.Now(). It would be nice if we could just pass time.Now() to the
// jwt.WithClock function, but it forces us to have something that implements
// the jwt.Clock interface to do it.
func (nc nowClock) Now() time.Time {
return time.Now()
}

func (client *OAuthClient) CreateOAuthClient(registerUrl string) ([]byte, error) {
// hydra endpoint: POST /clients
data := []byte(`{
Expand All @@ -30,7 +63,7 @@ func (client *OAuthClient) CreateOAuthClient(registerUrl string) ([]byte, error)
"state": "12345678910"
}`)

req, err := http.NewRequest("POST", registerUrl, bytes.NewBuffer(data))
req, err := http.NewRequest(http.MethodPost, registerUrl, bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
Expand Down Expand Up @@ -77,7 +110,7 @@ func (client *OAuthClient) AuthorizeOAuthClient(authorizeUrl string) ([]byte, er
"Content-Type": {"application/x-www-form-urlencoded"},
}

req, err := http.NewRequest("POST", authorizeUrl, bytes.NewBuffer(body))
req, err := http.NewRequest(http.MethodPost, authorizeUrl, bytes.NewBuffer(body))
req.Header = headers
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
Expand All @@ -101,7 +134,7 @@ func (client *OAuthClient) PerformTokenGrant(remoteUrl string) (string, error) {
"Content-Type": {"application/x-www-form-urlencoded"},
"Authorization": {"Bearer " + client.RegistrationAccessToken},
}
req, err := http.NewRequest("POST", remoteUrl, bytes.NewBuffer([]byte(body)))
req, err := http.NewRequest(http.MethodPost, remoteUrl, bytes.NewBuffer([]byte(body)))
req.Header = headers
if err != nil {
return "", fmt.Errorf("failed to make request: %s", err)
Expand Down Expand Up @@ -132,3 +165,141 @@ func QuoteArrayStrings(arr []string) []string {
}
return arr
}

// RequestClientCreds performs the requests to the OAuth2 server to obtain an
// access token for this client (BSS).
//
// 1. Register as OAuth2 client.
// 2. Authorize OAuth2 client that was created.
// 3. Obtain access token if OAuth2 client is authorized.
//
// Returns the OAuthClient struct containing the client ID, secret, etc. as well
// as the access token and an error if one occurred.
func (client *OAuthClient) RequestClientCreds() (accessToken string, err error) {
var (
url string
resp []byte
)

url = oauth2AdminBaseURL + "/admin/clients"
log.Printf("Attempting to register OAuth2 client")
debugf("Sending request to %s", url)
resp, err = client.CreateOAuthClient(url)
if err != nil {
err = fmt.Errorf("Failed to register OAuth2 client: %v", err)
debugf("Response: %v", string(resp))
return
}
log.Printf("Successfully registered OAuth2 client")
debugf("Client ID: %s", client.Id)

url = oauth2AdminBaseURL + "/oauth2/auth"
log.Printf("Attempting to authorize OAuth2 client")
debugf("Sending request to %s", url)
_, err = client.AuthorizeOAuthClient(url)
if err != nil {
err = fmt.Errorf("Failed to authorize OAuth2 client: %v", err)
debugf("Response: %v", string(resp))
return
}
log.Printf("Successfully authorized OAuth2 client")

url = oauth2PublicBaseURL + "/oauth2/token"
log.Printf("Attempting to fetch token from authorization server")
debugf("Sending request to %s", url)
accessToken, err = client.PerformTokenGrant(url)
if err != nil {
err = fmt.Errorf("Failed to fetch token from authorization server: %v", err)
return
}
log.Printf("Successfully fetched token")

return
}

// PollClientCreds tries retryCount times every retryInterval seconds to request
// client credentials and an access token (JWT) from the OAuth2 server. If
// attempts are exhausted or an invalid retryInterval is passed, an error is
// returned. If a JWT was successfully obtained, nil is returned.
func (client *OAuthClient) PollClientCreds(retryCount, retryInterval uint64) error {
retryDuration, err := time.ParseDuration(fmt.Sprintf("%ds", retryInterval))
if err != nil {
return fmt.Errorf("Invalid retry interval: %v", err)
}
for i := uint64(0); i < retryCount; i++ {
log.Printf("Attempting to obtain access token (attempt %d/%d)", i+1, retryCount)
token, err := client.RequestClientCreds()
if err != nil {
log.Printf("Failed to obtain client credentials and token: %v", err)
time.Sleep(retryDuration)
continue
}
log.Printf("Successfully obtained client credentials and token with %d attempts", i+1)
accessToken = token
return nil
}
log.Printf("Exhausted attempts to obtain client credentials and token")
return fmt.Errorf("Exhausted %d attempts at obtaining client credentials and token")
}

// JWTTestAndRefresh tests the current JWT. If either a parsing error occurs
// with it or the JWT is invalid, it attempts to fetch a new one. If all of this
// succeeds, nil is returned. Otherwise, an error is returned.
func (client *OAuthClient) JWTTestAndRefresh() (err error) {
var (
jwtIsValid bool
reason error
)

log.Printf("Validating JWT")
if accessToken != "" {
jwtIsValid, reason, err = JWTIsValid(accessToken)
if err != nil {
log.Printf("Unable to parse JWT, attempting to fetch a new one")
} else if !jwtIsValid {
log.Printf("JWT invalid, reason: %v", reason)
log.Printf("Attempting to fetch a new one")
} else {
log.Printf("JWT is valid")
return nil
}
} else {
log.Printf("No JWT detected, fetching a new one")
}

err = client.PollClientCreds(authRetryCount, authRetryWait)
if err != nil {
log.Printf("Polling for OAuth2 client credentials failed")
return fmt.Errorf("Failed to get access token: %v", err)
}
log.Printf("Successfully fetched new JWT")
return nil
}

// JWTIsValid takes a string representing a JWT and validates that it is not
// expired. If the JWT is invalid (timestamp(s) is/are out of range), jwtValid
// is set to false, reason is set to the reason why the JWT is not valid, and
// err is nil. If the JWT is valid (timestamps are all in range), jwtValid is
// set to true, reason is nil, and err is nil.
func JWTIsValid(jwtStr string) (jwtValid bool, reason, err error) {
var token jwt.Token
token, err = jwt.Parse([]byte(jwtStr))
if err != nil {
err = fmt.Errorf("failed to parse JWT string: %v", err)
return
}

// Right now, we only validate the issued at, expiry, and not before
// fields.
// TODO: Add full validation.
reason = jwt.Validate(token, jwt.WithClock(nowClock{}))
debugf("JWT valid between %v and %v", token.NotBefore(), token.Expiration())
debugf("Current time: %v", time.Now())
if reason == nil {
jwtValid = true
} else {
jwtValid = false
}

return
}
Loading

0 comments on commit 8f40cfa

Please sign in to comment.