Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Router on Postgres #388

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.20
require (
github.com/getsentry/sentry-go v0.23.0
github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8
github.com/lib/pq v1.10.9
github.com/onsi/ginkgo/v2 v2.11.0
github.com/onsi/gomega v1.27.10
github.com/prometheus/client_golang v1.16.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
Expand Down
176 changes: 80 additions & 96 deletions lib/router.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package router

import (
"database/sql"
"fmt"
"net/http"
"net/url"
Expand All @@ -13,8 +14,7 @@ import (
"github.com/alphagov/router/handlers"
"github.com/alphagov/router/logger"
"github.com/alphagov/router/triemux"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"github.com/lib/pq"
)

const (
Expand All @@ -38,18 +38,18 @@ const (
// come from, Route and Backend should not contain bson fields.
// MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module.
type Router struct {
mux *triemux.Mux
lock sync.RWMutex
mongoReadToOptime bson.MongoTimestamp
logger logger.Logger
opts Options
ReloadChan chan bool
mux *triemux.Mux
lock sync.RWMutex
logger logger.Logger
opts Options
ReloadChan chan bool
}

type Options struct {
MongoURL string
MongoDBName string
MongoPollInterval time.Duration
postgresURL string
postgresDbName string
listener *pq.Listener
dbPollInterval time.Duration
BackendConnTimeout time.Duration
BackendHeaderTimeout time.Duration
LogFileName string
Expand Down Expand Up @@ -92,7 +92,7 @@ func RegisterMetrics(r prometheus.Registerer) {
// NewRouter returns a new empty router instance. You will need to call
// SelfUpdateRoutes() to initialise the self-update process for routes.
func NewRouter(o Options) (rt *Router, err error) {
logInfo("router: using mongo poll interval:", o.MongoPollInterval)
logInfo("router: using database poll interval:", o.dbPollInterval)
logInfo("router: using backend connect timeout:", o.BackendConnTimeout)
logInfo("router: using backend header timeout:", o.BackendHeaderTimeout)

Expand All @@ -102,18 +102,26 @@ func NewRouter(o Options) (rt *Router, err error) {
}
logInfo("router: logging errors as JSON to", o.LogFileName)

mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1)
listenerProblemReporter := func(event pq.ListenerEventType, err error) {
if err != nil {
logWarn(fmt.Sprintf("pq: error creating listener for PSQL notify channel: %v)", err))
return
}
}

listener := pq.NewListener(o.postgresURL, 10*time.Second, time.Minute, listenerProblemReporter)

err = listener.Listen("events")
if err != nil {
return nil, err
panic(err)
}

reloadChan := make(chan bool, 1)
rt = &Router{
mux: triemux.NewMux(),
mongoReadToOptime: mongoReadToOptime,
logger: l,
opts: o,
ReloadChan: reloadChan,
mux: triemux.NewMux(),
logger: l,
opts: o,
ReloadChan: reloadChan,
}

go rt.pollAndReload()
Expand Down Expand Up @@ -150,9 +158,9 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

func (rt *Router) SelfUpdateRoutes() {
logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval))
logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.dbPollInterval))

tick := time.Tick(rt.opts.MongoPollInterval)
tick := time.Tick(rt.opts.dbPollInterval)
for range tick {
logDebug("router: polling MongoDB for changes")

Expand All @@ -172,32 +180,19 @@ func (rt *Router) pollAndReload() {
}
}()

logDebug("mgo: connecting to", rt.opts.MongoURL)
logDebug("pq: connecting to", rt.opts.postgresURL)

sess, err := mgo.Dial(rt.opts.MongoURL)
sess, err := sql.Open("postgres", rt.opts.postgresURL)
if err != nil {
logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err))
logWarn(fmt.Sprintf("pq: error connecting to PSQL database, skipping update (error: %v)", err))
return
}

defer sess.Close()
sess.SetMode(mgo.SecondaryPreferred, true)

currentMongoInstance, err := rt.getCurrentMongoInstance(sess.DB("admin"))
if err != nil {
logWarn(err)
return
}

logDebug("mgo: communicating with replica set member", currentMongoInstance.Name)

logDebug("router: polled mongo instance is ", currentMongoInstance.Name)
logDebug("router: polled mongo optime is ", currentMongoInstance.Optime)
logDebug("router: current read-to mongo optime is ", rt.mongoReadToOptime)

if rt.shouldReload(currentMongoInstance) {
if rt.shouldReload(rt.opts.listener) {
logDebug("router: updates found")
rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime)
rt.reloadRoutes(sess)
} else {
logDebug("router: no updates found")
}
Expand All @@ -212,7 +207,7 @@ type mongoDatabase interface {
// reloadRoutes reloads the routes for this Router instance on the fly. It will
// create a new proxy mux, load applications (backends) and routes into it, and
// then flip the "mux" pointer in the Router.
func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimestamp) {
func (rt *Router) reloadRoutes(db *sql.DB) {
defer func() {
// increment this metric regardless of whether the route reload succeeded
routeReloadCountMetric.Inc()
Expand All @@ -225,103 +220,86 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta
logger.NotifySentry(logger.ReportableError{Error: err})

routeReloadErrorCountMetric.Inc()
} else {
rt.mongoReadToOptime = currentOptime
}
}()

logInfo("router: reloading routes")
newmux := triemux.NewMux()

backends := rt.loadBackends(db.C("backends"))
loadRoutes(db.C("routes"), newmux, backends)
routeCount := newmux.RouteCount()
backends := rt.loadBackends(db)
loadRoutes(db, newmux, backends)

rt.lock.Lock()
rt.mux = newmux
rt.lock.Unlock()

logInfo(fmt.Sprintf("router: reloaded %d routes", routeCount))
routesCountMetric.Set(float64(routeCount))
}

func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) {
replicaSetStatus := bson.M{}
logInfo(fmt.Sprintf("router: reloaded %d routes", rt.mux.RouteCount()))

if err := db.Run("replSetGetStatus", &replicaSetStatus); err != nil {
return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't get replica set status from MongoDB, skipping update (error: %w)", err)
}

replicaSetStatusBytes, err := bson.Marshal(replicaSetStatus)
if err != nil {
return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't marshal replica set status from MongoDB, skipping update (error: %w)", err)
}

replicaSet := MongoReplicaSet{}
err = bson.Unmarshal(replicaSetStatusBytes, &replicaSet)
if err != nil {
return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't unmarshal replica set status from MongoDB, skipping update (error: %w)", err)
}
routesCountMetric.Set(float64(rt.mux.RouteCount()))
}

currentInstance := make([]MongoReplicaSetMember, 0)
for _, instance := range replicaSet.Members {
if instance.Current {
currentInstance = append(currentInstance, instance)
func (rt *Router) shouldReload(listener *pq.Listener) bool {
select {
case n := <-listener.Notify:
// n.Extra contains the payload from the notification
logInfo("notification:", n.Channel)
return true
default:
if err := listener.Ping(); err != nil {
panic(err)
}
return false
}

logDebug("router: MongoDB instances", currentInstance)

if len(currentInstance) != 1 {
return MongoReplicaSetMember{}, fmt.Errorf("router: did not find exactly one current MongoDB instance, skipping update (current instances found: %d)", len(currentInstance))
}

return currentInstance[0], nil
}

func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool {
return currentMongoInstance.Optime > rt.mongoReadToOptime
}

// loadBackends is a helper function which loads backends from the
// passed mongo collection, constructs a Handler for each one, and returns
// them in map keyed on the backend_id
func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) {
ffunc (rt *Router) loadBackends(db *sql.DB) (backends map[string]http.Handler) {
backend := &Backend{}
backends = make(map[string]http.Handler)

iter := c.Find(nil).Iter()
rows, err := db.Query("SELECT * FROM backends")
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

for rows.Next() {
err := rows.Scan(&backend.BackendID, &backend.BackendURL)
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

for iter.Next(&backend) {
backendURL, err := backend.ParseURL()
if err != nil {
logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+
"(error: %w), skipping", backend.BackendURL, backend.BackendID, err))
logWarn(fmt.Sprintf("router: couldn't parse URL %s for backends %s "+
"(error: %v), skipping!", backend.BackendURL, backend.BackendID, err))
continue
}

backends[backend.BackendID] = handlers.NewBackendHandler(
backend.BackendID,
backendURL,
rt.opts.BackendConnTimeout,
rt.opts.BackendHeaderTimeout,
rt.backendConnectTimeout, rt.backendHeaderTimeout,
rt.logger,
)
}

if err := iter.Err(); err != nil {
panic(err)
}

return
}

// loadRoutes is a helper function which loads routes from the passed mongo
// collection and registers them with the passed proxy mux.
func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) {
func loadRoutes(db *sql.DB, mux *triemux.Mux, backends map[string]http.Handler) {
route := &Route{}

iter := c.Find(nil).Iter()
rows, err := db.Query("SELECT * FROM routes")
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

goneHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "410 Gone", http.StatusGone)
Expand All @@ -330,8 +308,14 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha
http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable)
})

for iter.Next(&route) {
prefix := (route.RouteType == RouteTypePrefix)
for rows.Next() {
err := rows.Scan(&route.IncomingPath, &route.RouteType, &route.Handler, &route.Disabled, &route.BackendID, &route.RedirectTo, &route.RedirectType, &route.SegmentsMode)
if err != nil {
logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err))
return
}

prefix := (route.RouteType == "prefix")

// the database contains paths with % encoded routes.
// Unescape them here because the http.Request objects we match against contain the unescaped variants.
Expand Down
28 changes: 14 additions & 14 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ func main() {

router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != ""
var (
pubAddr = getenv("ROUTER_PUBADDR", ":8080")
apiAddr = getenv("ROUTER_APIADDR", ":8081")
mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1")
mongoDBName = getenv("ROUTER_MONGO_DB", "router")
mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s")
errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR")
tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != ""
beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s")
beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s")
feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s")
feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s")
pubAddr = getenv("ROUTER_PUBADDR", ":8080")
apiAddr = getenv("ROUTER_APIADDR", ":8081")
postgresURL = getenv("DATABASE_URL", "postgresql://postgres@127.0.0.1:27017/router?sslmode=disable")
postgresDbName = getenv("DATABASE_NAME", "router")
dbPollInterval = getenv("ROUTER_POLL_INTERVAL", "2s")
errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR")
tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != ""
beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s")
beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s")
feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s")
feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s")
)

log.Printf("using frontend read timeout: %v", feReadTimeout)
Expand All @@ -111,9 +111,9 @@ func main() {
router.RegisterMetrics(prometheus.DefaultRegisterer)

rout, err := router.NewRouter(router.Options{
MongoURL: mongoURL,
MongoDBName: mongoDBName,
MongoPollInterval: mongoPollInterval,
MongoURL: postgresURL,
MongoDBName: postgresDbName,
MongoPollInterval: dbPollInterval,
BackendConnTimeout: beConnTimeout,
BackendHeaderTimeout: beHeaderTimeout,
LogFileName: errorLogFile,
Expand Down