Skip to content

Commit

Permalink
fix: Improve code in various places
Browse files Browse the repository at this point in the history
  • Loading branch information
dadav committed Dec 25, 2024
1 parent fd20126 commit e0809a4
Show file tree
Hide file tree
Showing 15 changed files with 521 additions and 207 deletions.
22 changes: 20 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
FROM alpine:3.21@sha256:21dc6063fd678b478f57c0e13f47560d0ea4eeba26dfc947b2a4f81f686b9f45

# Create non-root user and set up permissions in a single layer
RUN adduser -k /dev/null -u 10001 -D gorge \
&& chgrp 0 /home/gorge \
&& chmod -R g+rwX /home/gorge
COPY gorge /
&& chmod -R g+rwX /home/gorge \
# Add additional security hardening
&& chmod 755 /gorge

# Copy application binary with explicit permissions
COPY --chmod=755 gorge /

# Set working directory
WORKDIR /home/gorge

# Switch to non-root user
USER 10001

# Define volume
VOLUME [ "/home/gorge" ]

# Set health check
HEALTHCHECK --interval=30s --timeout=3s \
CMD curl -f http://localhost:8080/readyz || exit 1

ENTRYPOINT ["/gorge"]
CMD [ "serve" ]
3 changes: 0 additions & 3 deletions cmd/config.go

This file was deleted.

41 changes: 24 additions & 17 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package cmd

import (
"fmt"
"os"
"log"
"path/filepath"
"strings"

Expand Down Expand Up @@ -47,8 +47,7 @@ var rootCmd = &cobra.Command{
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
log.Fatalf("Error executing root command: %v", err)
}
}

Expand All @@ -64,11 +63,9 @@ func initConfig(cmd *cobra.Command) error {
// Use config file from the flag.
v.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Println(err)
os.Exit(1)
return fmt.Errorf("failed to get home directory: %w", err)
}

homeConfig := filepath.Join(home, ".config")
Expand All @@ -81,27 +78,37 @@ func initConfig(cmd *cobra.Command) error {
v.SetEnvPrefix(envPrefix)
}

v.AutomaticEnv() // read in environment variables that match
v.AutomaticEnv()

// If a config file is found, read it in.
if err := v.ReadInConfig(); err == nil {
fmt.Println("Using config file:", v.ConfigFileUsed())
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
// Only return an error if it's not a missing config file
return fmt.Errorf("failed to read config file: %w", err)
}
} else {
log.Printf("Using config file: %s", v.ConfigFileUsed())
}

bindFlags(cmd, v)

return nil
return bindFlags(cmd, v)
}

func bindFlags(cmd *cobra.Command, v *viper.Viper) {
// bindFlags binds cobra flags with viper config
func bindFlags(cmd *cobra.Command, v *viper.Viper) error {
var bindingErrors []string

cmd.Flags().VisitAll(func(f *pflag.Flag) {
// Determine the naming convention of the flags when represented in the config file
configName := f.Name

// Apply the viper config value to the flag when the flag is not set and viper has a value
if !f.Changed && v.IsSet(configName) {
val := v.Get(configName)
cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val))
if err := cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)); err != nil {
bindingErrors = append(bindingErrors, fmt.Sprintf("failed to bind flag %s: %v", f.Name, err))
}
}
})

if len(bindingErrors) > 0 {
return fmt.Errorf("flag binding errors: %s", strings.Join(bindingErrors, "; "))
}
return nil
}
141 changes: 86 additions & 55 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,22 @@ You can also enable the caching functionality to speed things up.`,
userService := v3.NewUserOperationsApi()

r := chi.NewRouter()

// 1. Recoverer should be first to catch panics in all other middleware
r.Use(middleware.Recoverer)
// 2. RealIP should be early to ensure all other middleware sees the correct IP
r.Use(middleware.RealIP)
r.Use(customMiddleware.RequireUserAgent)
x := customMiddleware.NewStatistics()
r.Use(customMiddleware.StatisticsMiddleware(x))
// 3. CORS should be early as it might reject requests before doing unnecessary work
r.Use(cors.Handler(cors.Options{
AllowedOrigins: strings.Split(config.CORSOrigins, ","),
AllowedMethods: []string{"GET", "POST", "DELETE", "PATCH"},
AllowedHeaders: []string{"Accept", "Content-Type"},
AllowCredentials: false,
MaxAge: 300,
}))
if !config.NoCache {
customKeyFunc := func(r *http.Request) uint64 {
token := r.Header.Get("Authorization")
return stampede.StringToHash(r.Method, strings.ToLower(token))
}
cachedMiddleware := stampede.HandlerWithKey(512, time.Duration(config.CacheMaxAge)*time.Second, customKeyFunc, strings.Split(config.CachePrefixes, ",")...)
r.Use(cachedMiddleware)
}
// 4. RequireUserAgent should be early to ensure all other middleware sees the correct user agent
r.Use(customMiddleware.RequireUserAgent)

x := customMiddleware.NewStatistics()

if config.UI {
r.Group(func(r chi.Router) {
Expand All @@ -160,6 +155,55 @@ You can also enable the caching functionality to speed things up.`,
}

r.Group(func(r chi.Router) {
if !config.NoCache {
log.Log.Debug("Setting up cache middleware")
customKeyFunc := func(r *http.Request) uint64 {
token := r.Header.Get("Authorization")
return stampede.StringToHash(r.Method, r.URL.Path, strings.ToLower(token))
}

cachedMiddleware := stampede.HandlerWithKey(
512,
time.Duration(config.CacheMaxAge)*time.Second,
customKeyFunc,
)
log.Log.Debugf("Cache middleware configured with prefixes: %s", config.CachePrefixes)

cachePrefixes := strings.Split(config.CachePrefixes, ",")

r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shouldCache := false
for _, prefix := range cachePrefixes {
if strings.HasPrefix(r.URL.Path, strings.TrimSpace(prefix)) {
shouldCache = true
break
}
}

if shouldCache {
wrapper := customMiddleware.NewResponseWrapper(w)
// Set default cache status before serving
// w.Header().Set("X-Cache", "MISS from gorge")

cachedMiddleware(next).ServeHTTP(wrapper, r)

// Only override if it was served from cache
// TODO: this is not working as expected
if wrapper.WasWritten() {
log.Log.Debugf("Serving cached response for path: %s", r.URL.Path)
w.Header().Set("X-Cache", "HIT from gorge")
} else {
log.Log.Debugf("Cache miss for path: %s", r.URL.Path)
w.Header().Set("X-Cache", "MISS from gorge")
}
} else {
next.ServeHTTP(w, r)
}
})
})
}

if config.FallbackProxyUrl != "" {
proxies := strings.Split(config.FallbackProxyUrl, ",")
slices.Reverse(proxies)
Expand Down Expand Up @@ -190,6 +234,10 @@ You can also enable the caching functionality to speed things up.`,
))
}
}

// StatisticsMiddleware should be last to ensure all other middleware is counted
r.Use(customMiddleware.StatisticsMiddleware(x))

apiRouter := openapi.NewRouter(
openapi.NewModuleOperationsAPIController(moduleService),
openapi.NewReleaseOperationsAPIController(releaseService),
Expand All @@ -212,34 +260,35 @@ You can also enable the caching functionality to speed things up.`,
w.Write([]byte(`{"message": "ok"}`))
})

ctx, restoreDefaultSignalHandling := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Create signal handling context
sigCtx, restoreDefaultSignalHandling := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer restoreDefaultSignalHandling()
g, gCtx := errgroup.WithContext(ctx)
g, gCtx := errgroup.WithContext(sigCtx)

if err := backend.ConfiguredBackend.LoadModules(); err != nil {
log.Log.Fatal(fmt.Errorf("initial module load failed: %w", err))
}

// if set, continuously check modules directory every ModulesScanSec seconds
// otherwise, check only at startup
if config.ModulesScanSec > 0 {
g.Go(func() error {
// Call LoadModules immediately on startup
if err := backend.ConfiguredBackend.LoadModules(); err != nil {
return err
}
ticker := time.NewTicker(time.Duration(config.ModulesScanSec) * time.Second)
defer ticker.Stop()

for {
select {
case <-gCtx.Done():
log.Log.Debugln("Canceling module scan goroutine")
return nil
case <-time.After(time.Duration(config.ModulesScanSec) * time.Second):
case <-ticker.C:
if err := backend.ConfiguredBackend.LoadModules(); err != nil {
return err
log.Log.Errorf("Failed to load modules: %v", err)
// Continue running instead of failing completely
}
}
}
})
} else {
if err := backend.ConfiguredBackend.LoadModules(); err != nil {
log.Log.Panic(err)
}
}

bindPort := fmt.Sprintf("%s:%d", config.Bind, config.Port)
Expand All @@ -253,20 +302,14 @@ You can also enable the caching functionality to speed things up.`,
wantTLS := config.TlsKeyPath != "" && config.TlsCertPath != ""

if wantTLS {
certificate, err := os.ReadFile(config.TlsCertPath)
cert, err := tls.LoadX509KeyPair(config.TlsCertPath, config.TlsKeyPath)
if err != nil {
log.Log.Fatal(err)
}
key, err := os.ReadFile(config.TlsKeyPath)
if err != nil {
log.Log.Fatal(err)
}
cert, err := tls.X509KeyPair(certificate, key)
if err != nil {
log.Log.Fatal(err)
log.Log.Fatalf("Failed to load TLS certificates: %v", err)
}

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
server.TLSConfig = tlsConfig
}
Expand All @@ -293,12 +336,14 @@ You can also enable the caching functionality to speed things up.`,

g.Go(func() error {
<-gCtx.Done()

log.Log.Debugln("Shutting down server (timeout: 5s)")
gracefullCtx, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second)
shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelShutdown()

return server.Shutdown(gracefullCtx)
log.Log.Info("Shutting down server...")
if err := server.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("server shutdown failed: %w", err)
}
return nil
})

if err := g.Wait(); err != nil {
Expand Down Expand Up @@ -335,17 +380,3 @@ func init() {
serveCmd.Flags().BoolVar(&config.NoCache, "no-cache", false, "disables the caching functionality")
serveCmd.Flags().BoolVar(&config.ImportProxiedReleases, "import-proxied-releases", false, "add every proxied modules to local store")
}

func checkModules(sleepSeconds int) {
for {
err := backend.ConfiguredBackend.LoadModules()
if err != nil {
log.Log.Fatal(err)
}
if sleepSeconds > 0 {
time.Sleep(time.Duration(sleepSeconds) * time.Second)
} else {
break
}
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ require (
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/stretchr/testify v1.10.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
Expand Down
Loading

0 comments on commit e0809a4

Please sign in to comment.