From 8927a0a34dcfc326d8bab430743e7d4ac9bb9057 Mon Sep 17 00:00:00 2001 From: dadav <33197631+dadav@users.noreply.github.com> Date: Mon, 26 Feb 2024 22:53:24 +0100 Subject: [PATCH] feat: Add more code --- cmd/serve.go | 38 ++++++++---- go.mod | 1 + go.sum | 2 + internal/api/v3/module.go | 15 +++-- internal/backend/filesystem.go | 10 ++- internal/config/config.go | 14 +++-- internal/middleware/proxy.go | 107 +++++++++++++++++++++++++++++++++ 7 files changed, 157 insertions(+), 30 deletions(-) create mode 100644 internal/middleware/proxy.go diff --git a/cmd/serve.go b/cmd/serve.go index 7050f07..63edd58 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -18,14 +18,17 @@ package cmd import ( "fmt" "net/http" + "strings" v3 "github.com/dadav/gorge/internal/api/v3" backend "github.com/dadav/gorge/internal/backend" config "github.com/dadav/gorge/internal/config" log "github.com/dadav/gorge/internal/log" - middleware "github.com/dadav/gorge/internal/middleware" + customMiddleware "github.com/dadav/gorge/internal/middleware" openapi "github.com/dadav/gorge/pkg/gen/v3/openapi" "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" "github.com/spf13/cobra" ) @@ -57,15 +60,31 @@ to quickly create a Cobra application.`, userService := v3.NewUserOperationsApi() r := chi.NewRouter() - r.Use(middleware.RequireUserAgent) - handler := openapi.NewRouter( + r.Use(customMiddleware.RequireUserAgent) + + if config.FallbackProxyUrl != "" { + r.Use(customMiddleware.ProxyFallback(config.FallbackProxyUrl)) + } + + r.Use(middleware.Recoverer) + r.Use(middleware.RealIP) + + r.Use(cors.Handler(cors.Options{ + AllowedOrigins: strings.Split(config.CORSOrigins, ","), + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Accept", "Content-Type"}, + AllowCredentials: false, + MaxAge: 300, + })) + + apiRouter := openapi.NewRouter( openapi.NewModuleOperationsAPIController(moduleService), openapi.NewReleaseOperationsAPIController(releaseService), openapi.NewSearchFilterOperationsAPIController(searchFilterService), openapi.NewUserOperationsAPIController(userService), ) - r.Mount("/", handler) + r.Mount("/", apiRouter) log.Log.Infof("Listen on %s:%d", config.Bind, config.Port) log.Log.Panic(http.ListenAndServe(fmt.Sprintf("%s:%d", config.Bind, config.Port), r)) @@ -78,19 +97,12 @@ to quickly create a Cobra application.`, func init() { rootCmd.AddCommand(serveCmd) - // Here you will define your flags and configuration settings. - - // Cobra supports Persistent Flags which will work for this command - // and all subcommands, e.g.: - // serveCmd.PersistentFlags().String("foo", "", "A help for foo") - - // Cobra supports local flags which will only run when this command - // is called directly, e.g.: - // serveCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") serveCmd.Flags().StringVar(&config.ApiVersion, "api-version", "v3", "the forge api version to use") serveCmd.Flags().IntVar(&config.Port, "port", 8080, "the port to listen to") serveCmd.Flags().StringVar(&config.Bind, "bind", "", "host to listen to") serveCmd.Flags().StringVar(&config.ModulesDir, "modulesdir", "/opt/gorge/modules", "directory containing all the modules") serveCmd.Flags().StringVar(&config.Backend, "backend", "filesystem", "backend to use") + serveCmd.Flags().StringVar(&config.CORSOrigins, "cors", "*", "allowed cors origins separated by comma") + serveCmd.Flags().StringVar(&config.FallbackProxyUrl, "fallback-proxy", "", "optional fallback upstream proxy url") serveCmd.Flags().BoolVar(&config.Dev, "dev", false, "enables dev mode") } diff --git a/go.mod b/go.mod index ddce907..9d6f1fd 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.0 require ( github.com/go-chi/chi/v5 v5.0.12 + github.com/go-chi/cors v1.2.1 github.com/mitchellh/go-homedir v1.1.0 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.18.2 diff --git a/go.sum b/go.sum index bb18b35..37d5e6e 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= +github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= diff --git a/internal/api/v3/module.go b/internal/api/v3/module.go index d5d1546..e8940cd 100644 --- a/internal/api/v3/module.go +++ b/internal/api/v3/module.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/dadav/gorge/internal/backend" - "github.com/dadav/gorge/internal/log" gen "github.com/dadav/gorge/pkg/gen/v3/openapi" ) @@ -67,7 +66,7 @@ func (s *ModuleOperationsApi) DeprecateModule(ctx context.Context, moduleSlug st return gen.Response(http.StatusNotImplemented, nil), errors.New("DeprecateModule method not implemented") } -type GetModule500Response struct { +type GetModule404Response struct { Message string `json:"message,omitempty"` Errors []string `json:"errors,omitempty"` } @@ -76,13 +75,13 @@ type GetModule500Response struct { func (s *ModuleOperationsApi) GetModule(ctx context.Context, moduleSlug string, withHtml bool, includeFields []string, excludeFields []string, ifModifiedSince string) (gen.ImplResponse, error) { module, err := backend.ConfiguredBackend.GetModuleBySlug(moduleSlug) if err != nil { - log.Log.Error(err) + // log.Log.Error(err) return gen.Response( - http.StatusInternalServerError, - GetModule500Response{ - Message: http.StatusText(http.StatusInternalServerError), - Errors: []string{"There was some error while reading the metadata"}, - }), err + http.StatusNotFound, + GetModule404Response{ + Message: http.StatusText(http.StatusNotFound), + Errors: []string{"Module could not be found"}, + }), nil } return gen.Response(http.StatusOK, module), nil diff --git a/internal/backend/filesystem.go b/internal/backend/filesystem.go index 3767de3..42da8fc 100644 --- a/internal/backend/filesystem.go +++ b/internal/backend/filesystem.go @@ -22,10 +22,10 @@ import ( ) type FilesystemBackend struct { - muModules sync.Mutex + muModules sync.RWMutex Modules map[string]*gen.Module ModulesDir string - muReleases sync.Mutex + muReleases sync.RWMutex Releases map[string][]*gen.Release } @@ -93,7 +93,11 @@ func (s *FilesystemBackend) GetAllModules() []*gen.Module { func (s *FilesystemBackend) GetModuleBySlug(slug string) (*gen.Module, error) { s.muModules.Lock() defer s.muModules.Unlock() - return s.Modules[slug], nil + if module, ok := s.Modules[slug]; !ok { + return nil, errors.New("module not found") + } else { + return module, nil + } } func (s *FilesystemBackend) GetReleaseBySlug(slug string) (*gen.Release, error) { diff --git a/internal/config/config.go b/internal/config/config.go index ac91117..b5f34ce 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,10 +1,12 @@ package config var ( - ApiVersion string - Port int - Bind string - Dev bool - ModulesDir string - Backend string + ApiVersion string + Port int + Bind string + Dev bool + ModulesDir string + Backend string + CORSOrigins string + FallbackProxyUrl string ) diff --git a/internal/middleware/proxy.go b/internal/middleware/proxy.go new file mode 100644 index 0000000..f6c9097 --- /dev/null +++ b/internal/middleware/proxy.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "bytes" + "io" + "net/http" + "net/url" + + "github.com/dadav/gorge/internal/log" +) + +// capturedResponseWriter is a custom response writer that captures the response status +type capturedResponseWriter struct { + http.ResponseWriter + body []byte + status int +} + +func (w *capturedResponseWriter) WriteHeader(code int) { + w.status = code +} + +func (w *capturedResponseWriter) Write(body []byte) (int, error) { + w.body = body + return len(body), nil +} + +func (w *capturedResponseWriter) sendOriginalResponse() { + w.ResponseWriter.WriteHeader(w.status) + w.ResponseWriter.Write(w.body) +} + +func ProxyFallback(upstreamHost string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // capture response + capturedResponseWriter := &capturedResponseWriter{ResponseWriter: w} + next.ServeHTTP(capturedResponseWriter, r) + + if capturedResponseWriter.status == http.StatusNotFound { + log.Log.Infof("Forwarding request to %s\n", upstreamHost) + forwardRequest(w, r, upstreamHost) + return + } + + // If the response status is not 404, serve the original response + capturedResponseWriter.sendOriginalResponse() + }) + } +} + +func forwardRequest(w http.ResponseWriter, r *http.Request, forwardHost string) { + // Create a buffer to store the request body + var requestBodyBytes []byte + if r.Body != nil { + requestBodyBytes, _ = io.ReadAll(r.Body) + } + + // Clone the original request + forwardUrl, err := url.JoinPath(forwardHost, r.URL.Path) + if err != nil { + http.Error(w, "Failed to create forwarded request", http.StatusInternalServerError) + return + } + + forwardedRequest, err := http.NewRequest(r.Method, forwardUrl, bytes.NewBuffer(requestBodyBytes)) + if err != nil { + http.Error(w, "Failed to create forwarded request", http.StatusInternalServerError) + return + } + + // Copy headers from the original request + forwardedRequest.Header = make(http.Header) + for key, values := range r.Header { + for _, value := range values { + forwardedRequest.Header.Add(key, value) + } + } + + // Make the request to the forward host + // TODO: Add caching + client := http.Client{} + resp, err := client.Do(forwardedRequest) + if err != nil { + http.Error(w, "Failed to forward request", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Copy the response headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Write the response status code + w.WriteHeader(resp.StatusCode) + + // Write the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, "Failed to read response body", http.StatusInternalServerError) + return + } + w.Write(body) +}