diff --git a/go.mod b/go.mod index 4607a213..70512282 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/getsentry/sentry-go v0.23.0 github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 + github.com/lib/pq v1.10.9 github.com/onsi/ginkgo/v2 v2.11.0 github.com/onsi/gomega v1.27.10 github.com/prometheus/client_golang v1.16.0 diff --git a/go.sum b/go.sum index 6f44c516..29818207 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFF github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= diff --git a/lib/router.go b/lib/router.go index 03e56545..bd1dbfd9 100644 --- a/lib/router.go +++ b/lib/router.go @@ -1,6 +1,7 @@ package router import ( + "database/sql" "fmt" "net/http" "net/url" @@ -13,8 +14,7 @@ import ( "github.com/alphagov/router/handlers" "github.com/alphagov/router/logger" "github.com/alphagov/router/triemux" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" + "github.com/lib/pq" ) const ( @@ -38,18 +38,18 @@ const ( // come from, Route and Backend should not contain bson fields. // MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. 
type Router struct { - mux *triemux.Mux - lock sync.RWMutex - mongoReadToOptime bson.MongoTimestamp - logger logger.Logger - opts Options - ReloadChan chan bool + mux *triemux.Mux + lock sync.RWMutex + logger logger.Logger + opts Options + ReloadChan chan bool } type Options struct { - MongoURL string - MongoDBName string - MongoPollInterval time.Duration + postgresURL string + postgresDbName string + listener *pq.Listener + dbPollInterval time.Duration BackendConnTimeout time.Duration BackendHeaderTimeout time.Duration LogFileName string @@ -92,7 +92,7 @@ func RegisterMetrics(r prometheus.Registerer) { // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. func NewRouter(o Options) (rt *Router, err error) { - logInfo("router: using mongo poll interval:", o.MongoPollInterval) + logInfo("router: using database poll interval:", o.dbPollInterval) logInfo("router: using backend connect timeout:", o.BackendConnTimeout) logInfo("router: using backend header timeout:", o.BackendHeaderTimeout) @@ -102,18 +102,26 @@ func NewRouter(o Options) (rt *Router, err error) { } logInfo("router: logging errors as JSON to", o.LogFileName) - mongoReadToOptime, err := bson.NewMongoTimestamp(time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC), 1) + listenerProblemReporter := func(event pq.ListenerEventType, err error) { + if err != nil { + logWarn(fmt.Sprintf("pq: error creating listener for PSQL notify channel: %v)", err)) + return + } + } + + listener := pq.NewListener(o.postgresURL, 10*time.Second, time.Minute, listenerProblemReporter) + + err = listener.Listen("events") if err != nil { - return nil, err + panic(err) } reloadChan := make(chan bool, 1) rt = &Router{ - mux: triemux.NewMux(), - mongoReadToOptime: mongoReadToOptime, - logger: l, - opts: o, - ReloadChan: reloadChan, + mux: triemux.NewMux(), + logger: l, + opts: o, + ReloadChan: reloadChan, } go rt.pollAndReload() @@ -150,9 +158,9 @@ 
func (rt *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func (rt *Router) SelfUpdateRoutes() { - logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.MongoPollInterval)) + logInfo(fmt.Sprintf("router: starting self-update process, polling for route changes every %v", rt.opts.dbPollInterval)) - tick := time.Tick(rt.opts.MongoPollInterval) + tick := time.Tick(rt.opts.dbPollInterval) for range tick { logDebug("router: polling MongoDB for changes") @@ -172,32 +180,19 @@ func (rt *Router) pollAndReload() { } }() - logDebug("mgo: connecting to", rt.opts.MongoURL) + logDebug("pq: connecting to", rt.opts.postgresURL) - sess, err := mgo.Dial(rt.opts.MongoURL) + sess, err := sql.Open("postgres", rt.opts.postgresURL) if err != nil { - logWarn(fmt.Sprintf("mgo: error connecting to MongoDB, skipping update (error: %v)", err)) + logWarn(fmt.Sprintf("pq: error connecting to PSQL database, skipping update (error: %v)", err)) return } defer sess.Close() - sess.SetMode(mgo.SecondaryPreferred, true) - - currentMongoInstance, err := rt.getCurrentMongoInstance(sess.DB("admin")) - if err != nil { - logWarn(err) - return - } - - logDebug("mgo: communicating with replica set member", currentMongoInstance.Name) - logDebug("router: polled mongo instance is ", currentMongoInstance.Name) - logDebug("router: polled mongo optime is ", currentMongoInstance.Optime) - logDebug("router: current read-to mongo optime is ", rt.mongoReadToOptime) - - if rt.shouldReload(currentMongoInstance) { + if rt.shouldReload(rt.opts.listener) { logDebug("router: updates found") - rt.reloadRoutes(sess.DB(rt.opts.MongoDBName), currentMongoInstance.Optime) + rt.reloadRoutes(sess) } else { logDebug("router: no updates found") } @@ -212,7 +207,7 @@ type mongoDatabase interface { // reloadRoutes reloads the routes for this Router instance on the fly. 
It will // create a new proxy mux, load applications (backends) and routes into it, and // then flip the "mux" pointer in the Router. -func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimestamp) { +func (rt *Router) reloadRoutes(db *sql.DB) { defer func() { // increment this metric regardless of whether the route reload succeeded routeReloadCountMetric.Inc() @@ -225,103 +220,86 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta logger.NotifySentry(logger.ReportableError{Error: err}) routeReloadErrorCountMetric.Inc() - } else { - rt.mongoReadToOptime = currentOptime } }() logInfo("router: reloading routes") newmux := triemux.NewMux() - backends := rt.loadBackends(db.C("backends")) - loadRoutes(db.C("routes"), newmux, backends) - routeCount := newmux.RouteCount() + backends := rt.loadBackends(db) + loadRoutes(db, newmux, backends) rt.lock.Lock() rt.mux = newmux rt.lock.Unlock() - logInfo(fmt.Sprintf("router: reloaded %d routes", routeCount)) - routesCountMetric.Set(float64(routeCount)) -} - -func (rt *Router) getCurrentMongoInstance(db mongoDatabase) (MongoReplicaSetMember, error) { - replicaSetStatus := bson.M{} + logInfo(fmt.Sprintf("router: reloaded %d routes", rt.mux.RouteCount())) - if err := db.Run("replSetGetStatus", &replicaSetStatus); err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't get replica set status from MongoDB, skipping update (error: %w)", err) - } - - replicaSetStatusBytes, err := bson.Marshal(replicaSetStatus) - if err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't marshal replica set status from MongoDB, skipping update (error: %w)", err) - } - - replicaSet := MongoReplicaSet{} - err = bson.Unmarshal(replicaSetStatusBytes, &replicaSet) - if err != nil { - return MongoReplicaSetMember{}, fmt.Errorf("router: couldn't unmarshal replica set status from MongoDB, skipping update (error: %w)", err) - } + 
routesCountMetric.Set(float64(rt.mux.RouteCount())) +} - currentInstance := make([]MongoReplicaSetMember, 0) - for _, instance := range replicaSet.Members { - if instance.Current { - currentInstance = append(currentInstance, instance) +func (rt *Router) shouldReload(listener *pq.Listener) bool { + select { + case n := <-listener.Notify: + // n.Extra contains the payload from the notification + logInfo("notification:", n.Channel) + return true + default: + if err := listener.Ping(); err != nil { + panic(err) } + return false } - - logDebug("router: MongoDB instances", currentInstance) - - if len(currentInstance) != 1 { - return MongoReplicaSetMember{}, fmt.Errorf("router: did not find exactly one current MongoDB instance, skipping update (current instances found: %d)", len(currentInstance)) - } - - return currentInstance[0], nil -} - -func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool { - return currentMongoInstance.Optime > rt.mongoReadToOptime } // loadBackends is a helper function which loads backends from the // passed mongo collection, constructs a Handler for each one, and returns // them in map keyed on the backend_id -func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) { +func (rt *Router) loadBackends(db *sql.DB) (backends map[string]http.Handler) { backend := &Backend{} backends = make(map[string]http.Handler) - iter := c.Find(nil).Iter() + rows, err := db.Query("SELECT * FROM backends") + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } + + for rows.Next() { + err := rows.Scan(&backend.BackendID, &backend.BackendURL) + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. 
(error: %v)", err)) + return + } - for iter.Next(&backend) { backendURL, err := backend.ParseURL() if err != nil { - logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+ - "(error: %w), skipping", backend.BackendURL, backend.BackendID, err)) + logWarn(fmt.Sprintf("router: couldn't parse URL %s for backends %s "+ + "(error: %v), skipping!", backend.BackendURL, backend.BackendID, err)) continue } backends[backend.BackendID] = handlers.NewBackendHandler( backend.BackendID, backendURL, - rt.opts.BackendConnTimeout, - rt.opts.BackendHeaderTimeout, + rt.backendConnectTimeout, rt.backendHeaderTimeout, rt.logger, ) } - if err := iter.Err(); err != nil { - panic(err) - } - return } // loadRoutes is a helper function which loads routes from the passed mongo // collection and registers them with the passed proxy mux. -func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) { +func loadRoutes(db *sql.DB, mux *triemux.Mux, backends map[string]http.Handler) { route := &Route{} - iter := c.Find(nil).Iter() + rows, err := db.Query("SELECT * FROM routes") + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. (error: %v)", err)) + return + } goneHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "410 Gone", http.StatusGone) @@ -330,8 +308,14 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha http.Error(w, "503 Service Unavailable", http.StatusServiceUnavailable) }) - for iter.Next(&route) { - prefix := (route.RouteType == RouteTypePrefix) + for rows.Next() { + err := rows.Scan(&route.IncomingPath, &route.RouteType, &route.Handler, &route.Disabled, &route.BackendID, &route.RedirectTo, &route.RedirectType, &route.SegmentsMode) + if err != nil { + logWarn(fmt.Sprintf("pq: error retrieving row information from table, skipping update. 
(error: %v)", err)) + return + } + + prefix := (route.RouteType == "prefix") // the database contains paths with % encoded routes. // Unescape them here because the http.Request objects we match against contain the unescaped variants. diff --git a/main.go b/main.go index f2d878e3..34231af7 100644 --- a/main.go +++ b/main.go @@ -85,17 +85,17 @@ func main() { router.EnableDebugOutput = os.Getenv("ROUTER_DEBUG") != "" var ( - pubAddr = getenv("ROUTER_PUBADDR", ":8080") - apiAddr = getenv("ROUTER_APIADDR", ":8081") - mongoURL = getenv("ROUTER_MONGO_URL", "127.0.0.1") - mongoDBName = getenv("ROUTER_MONGO_DB", "router") - mongoPollInterval = getenvDuration("ROUTER_MONGO_POLL_INTERVAL", "2s") - errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") - tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" - beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") - beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") - feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") - feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + pubAddr = getenv("ROUTER_PUBADDR", ":8080") + apiAddr = getenv("ROUTER_APIADDR", ":8081") + postgresURL = getenv("DATABASE_URL", "postgresql://postgres@127.0.0.1:27017/router?sslmode=disable") + postgresDbName = getenv("DATABASE_NAME", "router") + dbPollInterval = getenv("ROUTER_POLL_INTERVAL", "2s") + errorLogFile = getenv("ROUTER_ERROR_LOG", "STDERR") + tlsSkipVerify = os.Getenv("ROUTER_TLS_SKIP_VERIFY") != "" + beConnTimeout = getenvDuration("ROUTER_BACKEND_CONNECT_TIMEOUT", "1s") + beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") + feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") + feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") ) log.Printf("using frontend read timeout: %v", feReadTimeout) @@ -111,9 +111,9 @@ func main() { router.RegisterMetrics(prometheus.DefaultRegisterer) rout, err := 
router.NewRouter(router.Options{ - MongoURL: mongoURL, - MongoDBName: mongoDBName, - MongoPollInterval: mongoPollInterval, + MongoURL: postgresURL, + MongoDBName: postgresDbName, + MongoPollInterval: dbPollInterval, BackendConnTimeout: beConnTimeout, BackendHeaderTimeout: beHeaderTimeout, LogFileName: errorLogFile,