Skip to content

Commit

Permalink
Merge pull request #17 from davidallendj/bss-auth
Browse files Browse the repository at this point in the history
Changed router to go-chi and added authentication middleware
  • Loading branch information
davidallendj authored Feb 22, 2024
2 parents 11d3aaa + 01ca9c0 commit 1e9b69e
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 30 deletions.
39 changes: 36 additions & 3 deletions cmd/boot-script-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const kvDefaultRetryCount uint64 = 10
const kvDefaultRetryWait uint64 = 5
const sqlDefaultRetryCount uint64 = 10
const sqlDefaultRetryWait uint64 = 5
const authDefaultRetryCount uint64 = 10

var (
httpListen = ":27778"
Expand Down Expand Up @@ -93,6 +94,9 @@ var (
sqlRetryWait = sqlDefaultRetryWait
notifier *ScnNotifier
useSQL = false // Use ETCD by default
requireAuth = false
authRetryCount = authDefaultRetryCount
jwksURL = ""
sqlDbOpts = ""
spireServiceURL = "https://spire-tokens.spire:54440"
)
Expand Down Expand Up @@ -146,7 +150,6 @@ func kvDefaultURL() string {
return ret
}


func kvDefaultRetryConfig() (retryCount uint64, retryWait uint64, err error) {
retryCount = kvDefaultRetryCount
retryWait = kvDefaultRetryWait
Expand Down Expand Up @@ -296,6 +299,18 @@ func parseEnvVars() error {
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_ENDPOINT_HOST: %q", parseErr))
}
parseErr = parseEnv("BSS_AUTH_RETRY_COUNT", &authRetryCount)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_AUTH_RETRY_COUNT: %q", parseErr))
}
parseErr = parseEnv("BSS_AUTH_REQUIRED", &requireAuth)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_AUTH_REQUIRED: %q", parseErr))
}
parseErr = parseEnv("BSS_JWKS_URL", &jwksURL)
if parseErr != nil {
errList = append(errList, fmt.Errorf("BSS_JWKS_URL: %q", parseErr))
}

//
// Etcd environment variables
Expand Down Expand Up @@ -390,12 +405,15 @@ func parseCmdLine() {
flag.StringVar(&bssdbName, "postgres-dbname", bssdbName, "(BSS_DBNAME) Postgres database name")
flag.StringVar(&sqlUser, "postgres-username", sqlUser, "(BSS_DBUSER) Postgres username")
flag.StringVar(&sqlPass, "postgres-password", sqlPass, "(BSS_DBPASS) Postgres password")
flag.StringVar(&jwksURL, "jwks-url", jwksURL, "(BSS_JWKS_URL) Set the JWKS URL to fetch the public key for authorization")
flag.BoolVar(&insecure, "insecure", insecure, "(BSS_INSECURE) Don't enforce https certificate security")
flag.BoolVar(&debugFlag, "debug", debugFlag, "(BSS_DEBUG) Enable debug output")
flag.BoolVar(&useSQL, "postgres", useSQL, "(BSS_USESQL) Use Postgres instead of ETCD")
flag.BoolVar(&requireAuth, "require-auth", requireAuth, "(BSS_REQUIRE_AUTH) Require JWTs authorization to allow using API endpoint")
flag.UintVar(&retryDelay, "retry-delay", retryDelay, "(BSS_RETRY_DELAY) Retry delay in seconds")
flag.UintVar(&hsmRetrievalDelay, "hsm-retrieval-delay", hsmRetrievalDelay, "(BSS_HSM_RETRIEVAL_DELAY) SM Retrieval delay in seconds")
flag.UintVar(&sqlPort, "postgres-port", sqlPort, "(BSS_DBPORT) Postgres port")
flag.Uint64Var(&authRetryCount, "auth-retry-count", authRetryCount, "(BSS_AUTH_RETRY_COUNT) Retry fetching JWKS public key set")
flag.Uint64Var(&sqlRetryCount, "postgres-retry-count", sqlRetryCount, "(BSS_SQL_RETRY_COUNT) Amount of times to retry connecting to Postgres")
flag.Uint64Var(&sqlRetryWait, "postgres-retry-wait", sqlRetryCount, "(BSS_SQL_RETRY_WAIT) Interval in seconds between connection attempts to Postgres")
flag.Parse()
Expand All @@ -414,7 +432,22 @@ func main() {
serviceName = sn
}
log.Printf("Service %s started", serviceName)
initHandlers()

router := initHandlers()

// try and fetch JWKS from issuer
if requireAuth {
for i := uint64(0); i <= authRetryCount; i++ {
err := loadPublicKeyFromURL(jwksURL)
if err != nil {
log.Printf("failed to initialize auth token: %v", err)
time.Sleep(5 * time.Second)
continue
}
log.Printf("Initialized the auth token successfully.")
break
}
}

var svcOpts string
if insecure {
Expand Down Expand Up @@ -457,5 +490,5 @@ func main() {
// NOTE: Should this be fatal??? Right now, we will continue.
log.Printf("WARNING: Spire join token service %s access failure: %s", spireServiceURL, err)
}
log.Fatal(http.ListenAndServe(httpListen, nil))
log.Fatal(http.ListenAndServe(httpListen, router))
}
94 changes: 77 additions & 17 deletions cmd/boot-script-service/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,97 @@
package main

import (
"context"
"fmt"
base "github.com/Cray-HPE/hms-base"
"net/http"
"time"

base "github.com/Cray-HPE/hms-base"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
)

const (
baseEndpoint = "/boot/v1"
notifierEndpoint = baseEndpoint + "/scn"
// We don't use the baseEndpoint here because cloud-init doesn't like them
metaDataRoute = "/meta-data"
userDataRoute = "/user-data"
phoneHomeRoute = "/phone-home"
metaDataRoute = "/meta-data"
userDataRoute = "/user-data"
phoneHomeRoute = "/phone-home"
)

func initHandlers() {
http.HandleFunc(baseEndpoint+"/", Index)
// config
http.HandleFunc(baseEndpoint+"/bootparameters", bootParameters)
var (
tokenAuth *jwtauth.JWTAuth
)

func loadPublicKeyFromURL(url string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
set, err := jwk.Fetch(ctx, url)
if err != nil {
return fmt.Errorf("%v", err)
}
for it := set.Iterate(context.Background()); it.Next(context.Background()); {
pair := it.Pair()
key := pair.Value.(jwk.Key)

var rawkey interface{}
if err := key.Raw(&rawkey); err != nil {
continue
}

tokenAuth = jwtauth.New(jwa.RS256.String(), nil, rawkey)
return nil
}

return fmt.Errorf("failed to load public key: %v", err)
}

func initHandlers() *chi.Mux {
router := chi.NewRouter()
router.Use(middleware.RequestID)
router.Use(middleware.RealIP)
router.Use(middleware.Logger)
router.Use(middleware.Recoverer)
router.Use(middleware.StripSlashes)
router.Use(middleware.Timeout(60 * time.Second))
if requireAuth {
router.Group(func(r chi.Router) {
r.Use(
jwtauth.Verifier(tokenAuth),
jwtauth.Authenticator(tokenAuth),
)

// protected routes if using auth
r.HandleFunc(baseEndpoint+"/", Index)
r.HandleFunc(baseEndpoint+"/bootparameters", bootParameters)
})
} else {
// public routes without auth
router.HandleFunc(baseEndpoint+"/", Index)
router.HandleFunc(baseEndpoint+"/bootparameters", bootParameters)
}
// every thing else is public
// boot
http.HandleFunc(baseEndpoint+"/bootscript", bootScript)
http.HandleFunc(baseEndpoint+"/hosts", hosts)
http.HandleFunc(baseEndpoint+"/dumpstate", dumpstate)
http.HandleFunc(baseEndpoint+"/service/", service)
router.HandleFunc(baseEndpoint+"/bootscript", bootScript)
router.HandleFunc(baseEndpoint+"/hosts", hosts)
router.HandleFunc(baseEndpoint+"/dumpstate", dumpstate)
router.HandleFunc(baseEndpoint+"/service/status", serviceStatusResponse)
router.HandleFunc(baseEndpoint+"/service/version", serviceVersionResponse)
router.HandleFunc(baseEndpoint+"/service/hsm", serviceHSMResponse)
router.HandleFunc(baseEndpoint+"/service/etcd", serviceETCDResponse)
// cloud-init
http.HandleFunc(metaDataRoute, metaDataGet)
http.HandleFunc(userDataRoute, userDataGet)
http.HandleFunc(phoneHomeRoute, phoneHomePost)
router.HandleFunc(metaDataRoute, metaDataGet)
router.HandleFunc(userDataRoute, userDataGet)
router.HandleFunc(phoneHomeRoute, phoneHomePost)
// notifications
http.HandleFunc(notifierEndpoint, scn)
router.HandleFunc(notifierEndpoint, scn)
// endpoint-access
http.HandleFunc(baseEndpoint+"/endpoint-history", endpointHistoryGet)
router.HandleFunc(baseEndpoint+"/endpoint-history", endpointHistoryGet)
return router
}

func Index(w http.ResponseWriter, r *http.Request) {
Expand Down
86 changes: 86 additions & 0 deletions cmd/boot-script-service/serviceAPI.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,92 @@ func serviceStatusAPI(w http.ResponseWriter, req *http.Request) {
fmt.Fprintln(w, string(out))
}

func serviceStatusResponse(w http.ResponseWriter, req *http.Request) {
var bssStatus serviceStatus
var httpStatus = http.StatusOK

bssStatus.Status = "running"

w.WriteHeader(httpStatus)
out, _ := json.Marshal(bssStatus)
fmt.Fprintln(w, string(out))
}

func serviceVersionResponse(w http.ResponseWriter, req *http.Request) {
var bssStatus serviceStatus
var httpStatus = http.StatusOK

dat, err := ioutil.ReadFile(".version")
if err != nil {
dat, err = ioutil.ReadFile("../../.version")
if err != nil {
httpStatus = http.StatusInternalServerError
dat = []byte("error")
log.Printf("Cannot read version file: %s", err)
}
}
bssStatus.Version = strings.TrimSpace(string(dat))

w.WriteHeader(httpStatus)
out, _ := json.Marshal(bssStatus)
fmt.Fprintln(w, string(out))
}

func serviceHSMResponse(w http.ResponseWriter, req *http.Request) {
var bssStatus serviceStatus
var httpStatus = http.StatusOK

bssStatus.HSMStatus = "connected"
url := smBaseURL + "/service/values/class"
rsp, err := smClient.Get(url)
if err != nil {
httpStatus = http.StatusInternalServerError
bssStatus.HSMStatus = "error"
log.Printf("Cannot connect to HSM: %s", err)
} else {
_, err = ioutil.ReadAll(rsp.Body)
if err != nil {
httpStatus = http.StatusInternalServerError
bssStatus.HSMStatus = "error"
log.Printf("Cannot read /service/values/class response from HSM: %s", err)
}
rsp.Body.Close()
}

w.WriteHeader(httpStatus)
out, _ := json.Marshal(bssStatus)
fmt.Fprintln(w, string(out))
}

func serviceETCDResponse(w http.ResponseWriter, req *http.Request) {
var bssStatus serviceStatus
var httpStatus = http.StatusOK

bssStatus.EctdStatus = "connected"
randnum := rand.Intn(255)
err := etcdTestStore(randnum)
if err != nil {
httpStatus = http.StatusInternalServerError
bssStatus.EctdStatus = "error"
log.Printf("Test store to etcd failed: %s", err)
} else {
ret, err := etcdTestGet()
if err != nil || ret != randnum {
httpStatus = http.StatusInternalServerError
bssStatus.EctdStatus = "error"
if err != nil {
log.Printf("Test read from etcd failed: %s", err)
} else {
log.Printf("Test read from etcd miscompare: Expected %d, Actual %d", randnum, ret)
}
}
}

w.WriteHeader(httpStatus)
out, _ := json.Marshal(bssStatus)
fmt.Fprintln(w, string(out))
}

func etcdTestStore(testId int) error {
data, err := json.Marshal(testId)
err = kvstore.Store("/bss/etcdTest", string(data))
Expand Down
20 changes: 17 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ require (

require (
github.com/OpenCHAMI/smd/v2 v2.12.15
github.com/go-chi/chi/v5 v5.0.11
github.com/go-chi/jwtauth v1.2.0
github.com/go-chi/jwtauth/v5 v5.3.0
github.com/golang-migrate/migrate/v4 v4.16.2
github.com/lestrrat-go/jwx v1.2.28
)

require (
Expand All @@ -27,8 +31,10 @@ require (
github.com/aws/aws-sdk-go v1.34.0 // indirect
github.com/cenkalti/backoff/v3 v3.2.2 // indirect
github.com/coreos/etcd v3.3.13+incompatible // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
Expand All @@ -42,17 +48,25 @@ require (
github.com/hashicorp/hcl v1.0.1-vault-5 // indirect
github.com/hashicorp/vault/api v1.9.2 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.4 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/jwx/v2 v2.0.17 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
go.etcd.io/etcd v3.3.13+incompatible // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/sys v0.11.0 // indirect
golang.org/x/text v0.12.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
google.golang.org/grpc v1.51.0 // indirect
Expand Down
Loading

0 comments on commit 1e9b69e

Please sign in to comment.