From 3eb0b2f83256e047f6105e332cfcbce0a1978a28 Mon Sep 17 00:00:00 2001 From: Jacob McSwain Date: Wed, 21 Feb 2024 09:15:31 -0600 Subject: [PATCH] allow disabling vtun on the server --- cmd/server.go | 47 +++++++----- internal/config/config.go | 2 + internal/server/api/controllers/v1/tunnels.go | 76 +++++++++++-------- internal/server/server.go | 4 +- 4 files changed, 80 insertions(+), 49 deletions(-) diff --git a/cmd/server.go b/cmd/server.go index 751a6686..324144e7 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -51,7 +51,11 @@ func runServer(cmd *cobra.Command, _ []string) error { defer cancel() olsrdExitedChan := olsrd.Run(ctx) - vtunExitedChan := vtun.Run(ctx) + + vtunExitedChan := make(chan struct{}) + if !config.DisableVTun { + vtunExitedChan = vtun.Run(ctx) + } // Start the metrics server go metrics.CreateMetricsServer(config, cmd.Root().Version) @@ -100,10 +104,13 @@ func runServer(cmd *cobra.Command, _ []string) error { } log.Printf("Interface watcher started") - // Start the vtun client watcher - vtunClientWatcher := vtun.NewVTunClientWatcher(db, config) - vtunClientWatcher.Run() - log.Printf("VTun client watcher started") + var vtunClientWatcher *vtun.VTunClientWatcher + if !config.DisableVTun { + // Start the vtun client watcher + vtunClientWatcher = vtun.NewVTunClientWatcher(db, config) + vtunClientWatcher.Run() + log.Printf("VTun client watcher started") + } // Start the server srv := server.NewServer(config, db, ifWatcher.Stats, eventBus.GetChannel(), vtunClientWatcher, wireguardManager) @@ -140,16 +147,18 @@ func runServer(cmd *cobra.Command, _ []string) error { } }) - errGrp.Go(func() error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - select { - case <-ctx.Done(): - return fmt.Errorf("vtund did not exit in time") - case <-vtunExitedChan: - return nil - } - }) + if !config.DisableVTun { + errGrp.Go(func() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + select { + case <-ctx.Done(): + return fmt.Errorf("vtund did not exit in time") + case <-vtunExitedChan: + return nil + } + }) + } errGrp.Go(func() error { return wireguardManager.Stop() @@ -159,9 +168,11 @@ func runServer(cmd *cobra.Command, _ []string) error { return srv.Stop() }) - errGrp.Go(func() error { - return vtunClientWatcher.Stop() - }) + if !config.DisableVTun { + errGrp.Go(func() error { + return vtunClientWatcher.Stop() + }) + } errGrp.Go(func() error { return ifWatcher.Stop() diff --git a/internal/config/config.go b/internal/config/config.go index 13aba62c..a15f710a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,7 @@ type Config struct { Latitude string Longitude string Gridsquare string + DisableVTun bool } func loadConfig() Config { @@ -97,6 +98,7 @@ func loadConfig() Config { Latitude: os.Getenv("SERVER_LAT"), Longitude: os.Getenv("SERVER_LON"), Gridsquare: os.Getenv("SERVER_GRIDSQUARE"), + DisableVTun: os.Getenv("DISABLE_VTUN") != "", } if tmpConfig.VTUNStartingAddress == "" { diff --git a/internal/server/api/controllers/v1/tunnels.go b/internal/server/api/controllers/v1/tunnels.go index 371ce076..28b805f9 100644 --- a/internal/server/api/controllers/v1/tunnels.go +++ b/internal/server/api/controllers/v1/tunnels.go @@ -206,11 +206,14 @@ func POSTTunnel(c *gin.Context) { return } - vtunClientWatcher, ok := c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher) - if !ok { - fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"}) - return + var vtunClientWatcher *vtun.VTunClientWatcher + if !config.DisableVTun { + vtunClientWatcher, ok = c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher) + if !ok { + fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"}) + return + } } wireguardManager, ok := c.MustGet("WireguardManager").(*wireguard.Manager) @@ -231,6 +234,11 @@ func POSTTunnel(c *gin.Context) { return } + if !json.Wireguard && config.DisableVTun { + c.JSON(http.StatusBadRequest, gin.H{"error": "VTun is disabled"}) + return + } + if !json.Client { json.Hostname = strings.ToUpper(json.Hostname) isValid, errString := json.IsValidHostname() @@ -463,7 +471,7 @@ func POSTTunnel(c *gin.Context) { return } - if !tunnel.Wireguard { + if !tunnel.Wireguard && !config.DisableVTun { err = vtun.GenerateAndSaveClient(config, db) if err != nil { fmt.Printf("POSTTunnel: Error generating vtun client config: %v\n", err) @@ -477,7 +485,7 @@ func POSTTunnel(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"}) return } - } else { + } else if tunnel.Wireguard { err = wireguardManager.AddPeer(tunnel) if err != nil { fmt.Printf("POSTTunnel: Error adding wireguard peer: %v\n", err) @@ -527,11 +535,14 @@ func PATCHTunnel(c *gin.Context) { return } - vtunClientWatcher, ok := c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher) - if !ok { - fmt.Println("PATCHTunnel: Unable to get VTunClientWatcher from context") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"}) - return + var vtunClientWatcher *vtun.VTunClientWatcher + if !config.DisableVTun { + vtunClientWatcher, ok = c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher) + if !ok { + fmt.Println("PATCHTunnel: Unable to get VTunClientWatcher from context") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"}) + return + } } wireguardManager, ok := c.MustGet("WireguardManager").(*wireguard.Manager) @@ -656,7 +667,7 @@ func PATCHTunnel(c *gin.Context) { return } - if !tunnel.Wireguard { + if !tunnel.Wireguard && !config.DisableVTun { err = vtun.GenerateAndSave(config, db) if err != nil { fmt.Printf("PATCHTunnel: Error generating vtun config: %v\n", err) @@ -677,7 +688,7 @@ func PATCHTunnel(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"}) return } - } else { + } else if tunnel.Wireguard { err = wireguardManager.RemovePeer(origTunnel) if err != nil { fmt.Printf("PATCHTunnel: Error adding wireguard peer: %v\n", err) @@ -732,11 +743,14 @@ func DELETETunnel(c *gin.Context) { return } - vtunClientWatcher, ok := c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher) - if !ok { - fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context") - c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"}) - return + var vtunClientWatcher *vtun.VTunClientWatcher + if !config.DisableVTun { + vtunClientWatcher, ok = c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher) + if !ok { + fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"}) + return + } } wireguardManager, ok := c.MustGet("WireguardManager").(*wireguard.Manager) @@ -807,18 +821,20 @@ func DELETETunnel(c *gin.Context) { return } - err = vtun.Reload() - if err != nil { - fmt.Printf("Error reloading vtun: %v\n", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun"}) - return - } + if !config.DisableVTun { + err = vtun.Reload() + if err != nil { + fmt.Printf("Error reloading vtun: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun"}) + return + } - err = vtun.ReloadAllClients(db, vtunClientWatcher) - if err != nil { - fmt.Printf("DELETETunnel: Error reloading vtun client: %v\n", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"}) - return + err = vtun.ReloadAllClients(db, vtunClientWatcher) + if err != nil { + fmt.Printf("DELETETunnel: Error reloading vtun client: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"}) + return + } } err = olsrd.Reload() diff --git a/internal/server/server.go b/internal/server/server.go index 90695f46..1f656ea5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -131,7 +131,9 @@ func (s *Server) addMiddleware(r *gin.Engine, version string) { r.Use(middleware.DatabaseProvider(s.db)) r.Use(middleware.OLSRDProvider(olsrd.NewHostsParser())) r.Use(middleware.OLSRDServicesProvider(olsrd.NewServicesParser())) - r.Use(middleware.VTunClientWatcherProvider(s.vtunClientWatcher)) + if !s.config.DisableVTun { + r.Use(middleware.VTunClientWatcherProvider(s.vtunClientWatcher)) + } r.Use(middleware.WireguardManagerProvider(s.wireguardManager)) r.Use(middleware.NetworkStats(s.stats)) r.Use(middleware.PaginatedDatabaseProvider(s.db, middleware.PaginationConfig{}))