diff --git a/integration_tests/backend_helpers.go b/integration_tests/backend_helpers.go index fd88d6c6..07b2611d 100644 --- a/integration_tests/backend_helpers.go +++ b/integration_tests/backend_helpers.go @@ -1,6 +1,7 @@ package integration import ( + "net" "net/http" "net/http/httptest" "strconv" @@ -12,14 +13,39 @@ import ( "github.com/onsi/gomega/ghttp" ) -func startSimpleBackend(identifier string) *httptest.Server { - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +var backends = map[string]string{ + "backend-1": "127.0.0.1:6789", + "backend-2": "127.0.0.1:6790", + "outer": "127.0.0.1:6792", + "inner": "127.0.0.1:6793", + "innerer": "127.0.0.1:6794", + "root": "127.0.0.1:6795", + "other": "127.0.0.1:6796", + "fallthrough": "127.0.0.1:6797", + "down": "127.0.0.1:6798", + "slow-1": "127.0.0.1:6799", + "slow-2": "127.0.0.1:6800", + "backend": "127.0.0.1:6801", + "be": "127.0.0.1:6802", + "not-running": "127.0.0.1:6803", + "with-path": "127.0.0.1:6804", +} + +func startSimpleBackend(identifier, host string) *httptest.Server { + l, err := net.Listen("tcp", host) + Expect(err).NotTo(HaveOccurred()) + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte(identifier)) Expect(err).NotTo(HaveOccurred()) })) + ts.Listener.Close() + ts.Listener = l + ts.Start() + return ts } -func startTarpitBackend(delays ...time.Duration) *httptest.Server { +func startTarpitBackend(host string, delays ...time.Duration) *httptest.Server { responseDelay := 2 * time.Second if len(delays) > 0 { responseDelay = delays[0] @@ -28,7 +54,11 @@ func startTarpitBackend(delays ...time.Duration) *httptest.Server { if len(delays) > 1 { bodyDelay = delays[1] } - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + l, err := net.Listen("tcp", host) + Expect(err).NotTo(HaveOccurred()) + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body := "Tarpit\n" if responseDelay > 0 { @@ -44,24 +74,26 @@ func startTarpitBackend(delays ...time.Duration) *httptest.Server { _, err := w.Write([]byte(body)) Expect(err).NotTo(HaveOccurred()) })) + ts.Listener.Close() + ts.Listener = l + ts.Start() + return ts } -func startRecordingBackend() *ghttp.Server { - return startRecordingServer(false) -} - -func startRecordingTLSBackend() *ghttp.Server { - return startRecordingServer(true) -} +func startRecordingBackend(tls bool, host string) *ghttp.Server { + l, err := net.Listen("tcp", host) + Expect(err).NotTo(HaveOccurred()) -func startRecordingServer(tls bool) (server *ghttp.Server) { + ts := ghttp.NewUnstartedServer() + ts.HTTPTestServer.Listener.Close() + ts.HTTPTestServer.Listener = l if tls { - server = ghttp.NewTLSServer() + ts.HTTPTestServer.StartTLS() } else { - server = ghttp.NewServer() + ts.Start() } - server.AllowUnhandledRequests = true - server.UnhandledRequestStatusCode = http.StatusOK - return server + ts.AllowUnhandledRequests = true + ts.UnhandledRequestStatusCode = http.StatusOK + return ts } diff --git a/integration_tests/integration_test.go b/integration_tests/integration_test.go index a5fcbd34..e716ac91 100644 --- a/integration_tests/integration_test.go +++ b/integration_tests/integration_test.go @@ -20,7 +20,14 @@ var _ = BeforeSuite(func() { if err != nil { Fail(err.Error()) } - err = startRouter(routerPort, apiPort, nil) + + backendEnvVars := []string{} + for id, host := range backends { + envVar := "BACKEND_URL_" + id + "=http://" + host + backendEnvVars = append(backendEnvVars, envVar) + } + + err = startRouter(routerPort, apiPort, backendEnvVars) if err != nil { Fail(err.Error()) } diff --git a/integration_tests/performance_test.go b/integration_tests/performance_test.go index 08be778e..697cb841 100644 --- a/integration_tests/performance_test.go +++ b/integration_tests/performance_test.go @@ -20,10 +20,8 @@ var _ = Describe("Performance", func() { ) BeforeEach(func() { - backend1 = startSimpleBackend("backend 1") - backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + backend1 = startSimpleBackend("backend 1", backends["backend-1"]) + backend2 = startSimpleBackend("backend 2", backends["backend-2"]) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) reloadRoutes(apiPort) @@ -58,10 +56,9 @@ var _ = Describe("Performance", func() { Describe("with one slow backend hit separately", func() { It("Router should not cause errors or much latency", func() { - slowBackend := startTarpitBackend(time.Second) + slowBackend := startTarpitBackend(backends["slow-1"], time.Second) defer slowBackend.Close() - addBackend("backend-slow", slowBackend.URL) - addRoute("/slow", NewBackendRoute("backend-slow")) + addRoute("/slow", NewBackendRoute("slow-1")) reloadRoutes(apiPort) _, gen := generateLoad([]string{routerURL(routerPort, "/slow")}, 50) @@ -73,8 +70,7 @@ var _ = Describe("Performance", func() { Describe("with one downed backend hit separately", func() { It("Router should not cause errors or much latency", func() { - addBackend("backend-down", "http://127.0.0.1:3162/") - addRoute("/down", NewBackendRoute("backend-down")) + addRoute("/down", NewBackendRoute("down")) reloadRoutes(apiPort) _, gen := generateLoad([]string{routerURL(routerPort, "/down")}, 50) @@ -96,12 +92,10 @@ var _ = Describe("Performance", func() { var backend2 *httptest.Server BeforeEach(func() { - backend1 = startTarpitBackend(time.Second) - backend2 = startTarpitBackend(time.Second) - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) - addRoute("/one", NewBackendRoute("backend-1")) - addRoute("/two", NewBackendRoute("backend-2")) + backend1 = startTarpitBackend(backends["slow-1"], time.Second) + backend2 = startTarpitBackend(backends["slow-2"], time.Second) + addRoute("/one", NewBackendRoute("slow-1")) + addRoute("/two", NewBackendRoute("slow-2")) reloadRoutes(apiPort) }) AfterEach(func() { diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index cb383137..127d0321 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -19,7 +19,6 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("connecting to the backend", func() { It("should return a 502 if the connection to the backend is refused", func() { - addBackend("not-running", "http://127.0.0.1:3164/") addRoute("/not-running", NewBackendRoute("not-running")) reloadRoutes(apiPort) @@ -31,21 +30,20 @@ var _ = Describe("Functioning as a reverse proxy", func() { logDetails := lastRouterErrorLogEntry() Expect(logDetails.Fields).To(Equal(map[string]interface{}{ - "error": "dial tcp 127.0.0.1:3164: connect: connection refused", + "error": "dial tcp 127.0.0.1:6803: connect: connection refused", "request": "GET /not-running HTTP/1.1", "request_method": "GET", "status": float64(502), // All numbers in JSON are floating point - "upstream_addr": "127.0.0.1:3164", + "upstream_addr": "127.0.0.1:6803", })) Expect(logDetails.Timestamp).To(BeTemporally("~", time.Now(), time.Second)) }) It("should log and return a 504 if the connection times out in the configured time", func() { - err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_CONNECT_TIMEOUT=0.3s"}) + err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_CONNECT_TIMEOUT=0.3s", "BACKEND_URL_black-hole=http://240.0.0.0:1234/"}) Expect(err).NotTo(HaveOccurred()) defer stopRouter(3167) - addBackend("black-hole", "http://240.0.0.0:1234/") addRoute("/should-time-out", NewBackendRoute("black-hole")) reloadRoutes(3166) @@ -74,14 +72,12 @@ var _ = Describe("Functioning as a reverse proxy", func() { var tarpit1, tarpit2 *httptest.Server BeforeEach(func() { - err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s"}) + err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_HEADER_TIMEOUT=0.3s", "BACKEND_URL_slow-1=http://127.0.0.1:6256/", "BACKEND_URL_slow-2=http://127.0.0.1:6253/"}) Expect(err).NotTo(HaveOccurred()) - tarpit1 = startTarpitBackend(time.Second) - tarpit2 = startTarpitBackend(100*time.Millisecond, 500*time.Millisecond) - addBackend("tarpit1", tarpit1.URL) - addBackend("tarpit2", tarpit2.URL) - addRoute("/tarpit1", NewBackendRoute("tarpit1")) - addRoute("/tarpit2", NewBackendRoute("tarpit2")) + tarpit1 = startTarpitBackend("127.0.0.1:6256", time.Second) + tarpit2 = startTarpitBackend("127.0.0.1:6253", 100*time.Millisecond, 500*time.Millisecond) + addRoute("/tarpit1", NewBackendRoute("slow-1")) + addRoute("/tarpit2", NewBackendRoute("slow-2")) reloadRoutes(3166) }) @@ -118,8 +114,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("header handling", func() { BeforeEach(func() { - recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + recorder = startRecordingBackend(false, backends["backend"]) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) @@ -242,8 +237,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("request verb, path, query and body handling", func() { BeforeEach(func() { - recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + recorder = startRecordingBackend(false, backends["backend"]) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) @@ -298,18 +292,20 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("handling a backend with a non '/' path", func() { BeforeEach(func() { - recorder = startRecordingBackend() - addBackend("backend", recorder.URL()+"/something") - addRoute("/foo/bar", NewBackendRoute("backend", "prefix")) - reloadRoutes(apiPort) + err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1", "BACKEND_URL_with-path=http://127.0.0.1:6804/something"}) + Expect(err).NotTo(HaveOccurred()) + recorder = startRecordingBackend(false, backends["with-path"]) + addRoute("/foo/bar", NewBackendRoute("with-path", "prefix")) + reloadRoutes(3166) }) AfterEach(func() { recorder.Close() + stopRouter(3167) }) It("should merge the 2 paths", func() { - resp := routerRequest(routerPort, "/foo/bar") + resp := routerRequest(3167, "/foo/bar") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -318,7 +314,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) It("should preserve the request query string", func() { - resp := routerRequest(routerPort, "/foo/bar?baz=qux") + resp := routerRequest(3167, "/foo/bar?baz=qux") Expect(resp.StatusCode).To(Equal(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) @@ -329,8 +325,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("handling HTTP/1.0 requests", func() { BeforeEach(func() { - recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + recorder = startRecordingBackend(false, backends["backend"]) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) @@ -362,10 +357,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("handling requests to a HTTPS backend", func() { BeforeEach(func() { - err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"}) + err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1", "BACKEND_URL_backend=https://127.0.0.1:2486"}) Expect(err).NotTo(HaveOccurred()) - recorder = startRecordingTLSBackend() - addBackend("backend", recorder.URL()) + recorder = startRecordingBackend(true, "127.0.0.1:2486") addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(3166) }) diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 0179d4ac..14089021 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -222,8 +222,7 @@ var _ = Describe("Redirection", func() { var recorder *ghttp.Server BeforeEach(func() { - recorder = startRecordingBackend() - addBackend("be", recorder.URL()) + recorder = startRecordingBackend(false, backends["be"]) addRoute("/guidance/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) addRoute("/GUIDANCE/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) reloadRoutes(apiPort) diff --git a/integration_tests/route_helpers.go b/integration_tests/route_helpers.go index 84ed1e53..4d2795b8 100644 --- a/integration_tests/route_helpers.go +++ b/integration_tests/route_helpers.go @@ -6,7 +6,6 @@ import ( "time" "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" // revive:disable:dot-imports . "github.com/onsi/ginkgo/v2" @@ -91,11 +90,6 @@ func initRouteHelper() error { return nil } -func addBackend(id, url string) { - err := routerDB.C("backends").Insert(bson.M{"backend_id": id, "backend_url": url}) - Expect(err).NotTo(HaveOccurred()) -} - func addRoute(path string, route Route) { route.IncomingPath = path diff --git a/integration_tests/route_loading_test.go b/integration_tests/route_loading_test.go index 827c1deb..57a1154d 100644 --- a/integration_tests/route_loading_test.go +++ b/integration_tests/route_loading_test.go @@ -1,7 +1,6 @@ package integration import ( - "fmt" "net/http/httptest" . "github.com/onsi/ginkgo/v2" @@ -12,14 +11,11 @@ var _ = Describe("loading routes from the db", func() { var ( backend1 *httptest.Server backend2 *httptest.Server - backend3 *httptest.Server ) BeforeEach(func() { - backend1 = startSimpleBackend("backend 1") - backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + backend1 = startSimpleBackend("backend 1", backends["backend-1"]) + backend2 = startSimpleBackend("backend 2", backends["backend-2"]) }) AfterEach(func() { backend1.Close() @@ -73,34 +69,4 @@ var _ = Describe("loading routes from the db", func() { Expect(readBody(resp)).To(Equal("backend 1")) }) }) - - Context("a backend an env var overriding the backend_url", func() { - BeforeEach(func() { - // This tests the behaviour of backend.ParseURL overriding the backend_url - // provided in the DB with the value of an env var - blackHole := "240.0.0.0/foo" - backend3 = startSimpleBackend("backend 3") - addBackend("backend-3", blackHole) - - stopRouter(routerPort) - err := startRouter(routerPort, apiPort, []string{fmt.Sprintf("BACKEND_URL_backend-3=%s", backend3.URL)}) - Expect(err).NotTo(HaveOccurred()) - - addRoute("/oof", NewBackendRoute("backend-3")) - reloadRoutes(apiPort) - }) - - AfterEach(func() { - stopRouter(routerPort) - err := startRouter(routerPort, apiPort, nil) - Expect(err).NotTo(HaveOccurred()) - backend3.Close() - }) - - It("should send requests to the backend_url provided in the env var", func() { - resp := routerRequest(routerPort, "/oof") - Expect(resp.StatusCode).To(Equal(200)) - Expect(readBody(resp)).To(Equal("backend 3")) - }) - }) }) diff --git a/integration_tests/route_selection_test.go b/integration_tests/route_selection_test.go index 4b60e44b..dacaf226 100644 --- a/integration_tests/route_selection_test.go +++ b/integration_tests/route_selection_test.go @@ -17,10 +17,8 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { - backend1 = startSimpleBackend("backend 1") - backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + backend1 = startSimpleBackend("backend 1", backends["backend-1"]) + backend2 = startSimpleBackend("backend 2", backends["backend-2"]) addRoute("/foo", NewBackendRoute("backend-1")) addRoute("/bar", NewBackendRoute("backend-2")) addRoute("/baz", NewBackendRoute("backend-1")) @@ -66,10 +64,8 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { - backend1 = startSimpleBackend("backend 1") - backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + backend1 = startSimpleBackend("backend 1", backends["backend-1"]) + backend2 = startSimpleBackend("backend 2", backends["backend-2"]) addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/bar", NewBackendRoute("backend-2", "prefix")) addRoute("/baz", NewBackendRoute("backend-1", "prefix")) @@ -121,11 +117,9 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { - outer = startSimpleBackend("outer") - inner = startSimpleBackend("inner") - addBackend("outer-backend", outer.URL) - addBackend("inner-backend", inner.URL) - addRoute("/foo", NewBackendRoute("outer-backend", "prefix")) + outer = startSimpleBackend("outer", backends["outer"]) + inner = startSimpleBackend("inner", backends["inner"]) + addRoute("/foo", NewBackendRoute("outer", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { @@ -135,7 +129,7 @@ var _ = Describe("Route selection", func() { Describe("with an exact child", func() { BeforeEach(func() { - addRoute("/foo/bar", NewBackendRoute("inner-backend")) + addRoute("/foo/bar", NewBackendRoute("inner")) reloadRoutes(apiPort) }) @@ -157,7 +151,7 @@ var _ = Describe("Route selection", func() { Describe("with a prefix child", func() { BeforeEach(func() { - addRoute("/foo/bar", NewBackendRoute("inner-backend", "prefix")) + addRoute("/foo/bar", NewBackendRoute("inner", "prefix")) reloadRoutes(apiPort) }) @@ -190,10 +184,9 @@ var _ = Describe("Route selection", func() { innerer *httptest.Server ) BeforeEach(func() { - innerer = startSimpleBackend("innerer") - addBackend("innerer-backend", innerer.URL) - addRoute("/foo/bar", NewBackendRoute("inner-backend")) - addRoute("/foo/bar/baz", NewBackendRoute("innerer-backend", "prefix")) + innerer = startSimpleBackend("innerer", backends["innerer"]) + addRoute("/foo/bar", NewBackendRoute("inner")) + addRoute("/foo/bar/baz", NewBackendRoute("innerer", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { @@ -243,10 +236,8 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { - backend1 = startSimpleBackend("backend 1") - backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + backend1 = startSimpleBackend("backend 1", backends["backend-1"]) + backend2 = startSimpleBackend("backend 2", backends["backend-2"]) addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/foo", NewBackendRoute("backend-2")) reloadRoutes(apiPort) @@ -274,10 +265,8 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { - root = startSimpleBackend("root backend") - other = startSimpleBackend("other backend") - addBackend("root", root.URL) - addBackend("other", other.URL) + root = startSimpleBackend("root backend", backends["root"]) + other = startSimpleBackend("other backend", backends["other"]) addRoute("/foo", NewBackendRoute("other")) }) AfterEach(func() { @@ -321,11 +310,9 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { - root = startSimpleBackend("fallthrough") - recorder = startRecordingBackend() - addBackend("root", root.URL) - addBackend("other", recorder.URL()) - addRoute("/", NewBackendRoute("root", "prefix")) + root = startSimpleBackend("fallthrough", backends["fallthrough"]) + recorder = startRecordingBackend(false, backends["other"]) + addRoute("/", NewBackendRoute("fallthrough", "prefix")) addRoute("/foo/bar", NewBackendRoute("other", "prefix")) reloadRoutes(apiPort) }) @@ -358,8 +345,7 @@ var _ = Describe("Route selection", func() { var recorder *ghttp.Server BeforeEach(func() { - recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + recorder = startRecordingBackend(false, backends["backend"]) }) AfterEach(func() { recorder.Close() diff --git a/lib/backends.go b/lib/backends.go new file mode 100644 index 00000000..85c04c43 --- /dev/null +++ b/lib/backends.go @@ -0,0 +1,49 @@ +package router + +import ( + "fmt" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/alphagov/router/handlers" + "github.com/alphagov/router/logger" +) + +func loadBackendsFromEnv(connTimeout, headerTimeout time.Duration, logger logger.Logger) (backends map[string]http.Handler) { + backends = make(map[string]http.Handler) + + for _, envvar := range os.Environ() { + pair := strings.SplitN(envvar, "=", 2) + + if !strings.HasPrefix(pair[0], "BACKEND_URL_") { + continue + } + + backendID := strings.TrimPrefix(pair[0], "BACKEND_URL_") + backendURL := pair[1] + + if backendURL == "" { + logWarn(fmt.Errorf("router: couldn't find URL for backend %s, skipping", backendID)) + continue + } + + backend, err := url.Parse(backendURL) + if err != nil { + logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s (error: %w), skipping", backendURL, backendID, err)) + continue + } + + backends[backendID] = handlers.NewBackendHandler( + backendID, + backend, + connTimeout, + headerTimeout, + logger, + ) + } + + return +} diff --git a/lib/backends_test.go b/lib/backends_test.go new file mode 100644 index 00000000..5d1f5748 --- /dev/null +++ b/lib/backends_test.go @@ -0,0 +1,41 @@ +package router + +import ( + "os" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Backends", func() { + Context("When calling loadBackendsFromEnv", func() { + It("should load backends from environment variables", func() { + os.Setenv("BACKEND_URL_testBackend", "http://example.com") + defer os.Unsetenv("BACKEND_URL_testBackend") + + backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil) + + Expect(backends).To(HaveKey("testBackend")) + Expect(backends["testBackend"]).ToNot(BeNil()) + }) + + It("should skip backends with empty URLs", func() { + os.Setenv("BACKEND_URL_emptyBackend", "") + defer os.Unsetenv("BACKEND_URL_emptyBackend") + + backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil) + + Expect(backends).ToNot(HaveKey("emptyBackend")) + }) + + It("should skip backends with invalid URLs", func() { + os.Setenv("BACKEND_URL_invalidBackend", "://invalid-url") + defer os.Unsetenv("BACKEND_URL_invalidBackend") + + backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil) + + Expect(backends).ToNot(HaveKey("invalidBackend")) + }) + }) +}) diff --git a/lib/router.go b/lib/router.go index 5fbac7bd..d7890ce6 100644 --- a/lib/router.go +++ b/lib/router.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "net/url" - "os" "strconv" "sync" "time" @@ -39,6 +38,7 @@ const ( // come from, Route and Backend should not contain bson fields. // MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { + backends map[string]http.Handler mux *triemux.Mux lock sync.RWMutex mongoReadToOptime bson.MongoTimestamp @@ -106,8 +106,11 @@ func NewRouter(o Options) (rt *Router, err error) { return nil, err } + backends := loadBackendsFromEnv(o.BackendConnTimeout, o.BackendHeaderTimeout, l) + reloadChan := make(chan bool, 1) rt = &Router{ + backends: backends, mux: triemux.NewMux(), mongoReadToOptime: mongoReadToOptime, logger: l, @@ -235,8 +238,7 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta logInfo("router: reloading routes") newmux := triemux.NewMux() - backends := rt.loadBackends(db.C("backends")) - loadRoutes(db.C("routes"), newmux, backends) + loadRoutes(db.C("routes"), newmux, rt.backends) routeCount := newmux.RouteCount() rt.lock.Lock() @@ -286,39 +288,6 @@ func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool return currentMongoInstance.Optime > rt.mongoReadToOptime } -// loadBackends is a helper function which loads backends from the -// passed mongo collection, constructs a Handler for each one, and returns -// them in map keyed on the backend_id -func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) { - backend := &Backend{} - backends = make(map[string]http.Handler) - - iter := c.Find(nil).Iter() - - for iter.Next(&backend) { - backendURL, err := backend.ParseURL() - if err != nil { - logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+ - "(error: %w), skipping", backend.BackendURL, backend.BackendID, err)) - continue - } - - backends[backend.BackendID] = handlers.NewBackendHandler( - backend.BackendID, - backendURL, - rt.opts.BackendConnTimeout, - rt.opts.BackendHeaderTimeout, - rt.logger, - ) - } - - if err := iter.Err(); err != nil { - panic(err) - } - - return -} - // loadRoutes is a helper function which loads routes from the passed mongo // collection and registers them with the passed proxy mux. func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) { @@ -378,14 +347,6 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha } } -func (be *Backend) ParseURL() (*url.URL, error) { - backendURL := os.Getenv(fmt.Sprintf("BACKEND_URL_%s", be.BackendID)) - if backendURL == "" { - backendURL = be.BackendURL - } - return url.Parse(backendURL) -} - func shouldPreserveSegments(route *Route) bool { switch route.RouteType { case RouteTypeExact: