Skip to content

Commit

Permalink
allow disabling vtun on the server
Browse files Browse the repository at this point in the history
  • Loading branch information
USA-RedDragon committed Feb 21, 2024
1 parent 91997c9 commit 3eb0b2f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 49 deletions.
47 changes: 29 additions & 18 deletions cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ func runServer(cmd *cobra.Command, _ []string) error {
defer cancel()

olsrdExitedChan := olsrd.Run(ctx)
vtunExitedChan := vtun.Run(ctx)

vtunExitedChan := make(chan struct{})
if !config.DisableVTun {
vtunExitedChan = vtun.Run(ctx)
}

// Start the metrics server
go metrics.CreateMetricsServer(config, cmd.Root().Version)
Expand Down Expand Up @@ -100,10 +104,13 @@ func runServer(cmd *cobra.Command, _ []string) error {
}
log.Printf("Interface watcher started")

// Start the vtun client watcher
vtunClientWatcher := vtun.NewVTunClientWatcher(db, config)
vtunClientWatcher.Run()
log.Printf("VTun client watcher started")
var vtunClientWatcher *vtun.VTunClientWatcher
if !config.DisableVTun {
// Start the vtun client watcher
vtunClientWatcher = vtun.NewVTunClientWatcher(db, config)
vtunClientWatcher.Run()
log.Printf("VTun client watcher started")
}

// Start the server
srv := server.NewServer(config, db, ifWatcher.Stats, eventBus.GetChannel(), vtunClientWatcher, wireguardManager)
Expand Down Expand Up @@ -140,16 +147,18 @@ func runServer(cmd *cobra.Command, _ []string) error {
}
})

errGrp.Go(func() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("vtund did not exit in time")
case <-vtunExitedChan:
return nil
}
})
if !config.DisableVTun {
errGrp.Go(func() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
select {
case <-ctx.Done():
return fmt.Errorf("vtund did not exit in time")
case <-vtunExitedChan:
return nil
}
})
}

errGrp.Go(func() error {
return wireguardManager.Stop()
Expand All @@ -159,9 +168,11 @@ func runServer(cmd *cobra.Command, _ []string) error {
return srv.Stop()
})

errGrp.Go(func() error {
return vtunClientWatcher.Stop()
})
if !config.DisableVTun {
errGrp.Go(func() error {
return vtunClientWatcher.Stop()
})
}

errGrp.Go(func() error {
return ifWatcher.Stop()
Expand Down
2 changes: 2 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type Config struct {
Latitude string
Longitude string
Gridsquare string
DisableVTun bool
}

func loadConfig() Config {
Expand Down Expand Up @@ -97,6 +98,7 @@ func loadConfig() Config {
Latitude: os.Getenv("SERVER_LAT"),
Longitude: os.Getenv("SERVER_LON"),
Gridsquare: os.Getenv("SERVER_GRIDSQUARE"),
DisableVTun: os.Getenv("DISABLE_VTUN") != "",
}

if tmpConfig.VTUNStartingAddress == "" {
Expand Down
76 changes: 46 additions & 30 deletions internal/server/api/controllers/v1/tunnels.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,14 @@ func POSTTunnel(c *gin.Context) {
return
}

vtunClientWatcher, ok := c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher)
if !ok {
fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"})
return
var vtunClientWatcher *vtun.VTunClientWatcher
if !config.DisableVTun {
vtunClientWatcher, ok = c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher)
if !ok {
fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"})
return
}
}

wireguardManager, ok := c.MustGet("WireguardManager").(*wireguard.Manager)
Expand All @@ -231,6 +234,11 @@ func POSTTunnel(c *gin.Context) {
return
}

if !json.Wireguard && config.DisableVTun {
c.JSON(http.StatusBadRequest, gin.H{"error": "VTun is disabled"})
return
}

if !json.Client {
json.Hostname = strings.ToUpper(json.Hostname)
isValid, errString := json.IsValidHostname()
Expand Down Expand Up @@ -463,7 +471,7 @@ func POSTTunnel(c *gin.Context) {
return
}

if !tunnel.Wireguard {
if !tunnel.Wireguard && !config.DisableVTun {
err = vtun.GenerateAndSaveClient(config, db)
if err != nil {
fmt.Printf("POSTTunnel: Error generating vtun client config: %v\n", err)
Expand All @@ -477,7 +485,7 @@ func POSTTunnel(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"})
return
}
} else {
} else if tunnel.Wireguard {
err = wireguardManager.AddPeer(tunnel)
if err != nil {
fmt.Printf("POSTTunnel: Error adding wireguard peer: %v\n", err)
Expand Down Expand Up @@ -527,11 +535,14 @@ func PATCHTunnel(c *gin.Context) {
return
}

vtunClientWatcher, ok := c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher)
if !ok {
fmt.Println("PATCHTunnel: Unable to get VTunClientWatcher from context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"})
return
var vtunClientWatcher *vtun.VTunClientWatcher
if !config.DisableVTun {
vtunClientWatcher, ok = c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher)
if !ok {
fmt.Println("PATCHTunnel: Unable to get VTunClientWatcher from context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"})
return
}
}

wireguardManager, ok := c.MustGet("WireguardManager").(*wireguard.Manager)
Expand Down Expand Up @@ -656,7 +667,7 @@ func PATCHTunnel(c *gin.Context) {
return
}

if !tunnel.Wireguard {
if !tunnel.Wireguard && !config.DisableVTun {
err = vtun.GenerateAndSave(config, db)
if err != nil {
fmt.Printf("PATCHTunnel: Error generating vtun config: %v\n", err)
Expand All @@ -677,7 +688,7 @@ func PATCHTunnel(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"})
return
}
} else {
} else if tunnel.Wireguard {
err = wireguardManager.RemovePeer(origTunnel)
if err != nil {
fmt.Printf("PATCHTunnel: Error adding wireguard peer: %v\n", err)
Expand Down Expand Up @@ -732,11 +743,14 @@ func DELETETunnel(c *gin.Context) {
return
}

vtunClientWatcher, ok := c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher)
if !ok {
fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"})
return
var vtunClientWatcher *vtun.VTunClientWatcher
if !config.DisableVTun {
vtunClientWatcher, ok = c.MustGet("VTunClientWatcher").(*vtun.VTunClientWatcher)
if !ok {
fmt.Println("DELETETunnel: Unable to get VTunClientWatcher from context")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Try again later"})
return
}
}

wireguardManager, ok := c.MustGet("WireguardManager").(*wireguard.Manager)
Expand Down Expand Up @@ -807,18 +821,20 @@ func DELETETunnel(c *gin.Context) {
return
}

err = vtun.Reload()
if err != nil {
fmt.Printf("Error reloading vtun: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun"})
return
}
if !config.DisableVTun {
err = vtun.Reload()
if err != nil {
fmt.Printf("Error reloading vtun: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun"})
return
}

err = vtun.ReloadAllClients(db, vtunClientWatcher)
if err != nil {
fmt.Printf("DELETETunnel: Error reloading vtun client: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"})
return
err = vtun.ReloadAllClients(db, vtunClientWatcher)
if err != nil {
fmt.Printf("DELETETunnel: Error reloading vtun client: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reloading vtun client"})
return
}
}

err = olsrd.Reload()
Expand Down
4 changes: 3 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ func (s *Server) addMiddleware(r *gin.Engine, version string) {
r.Use(middleware.DatabaseProvider(s.db))
r.Use(middleware.OLSRDProvider(olsrd.NewHostsParser()))
r.Use(middleware.OLSRDServicesProvider(olsrd.NewServicesParser()))
r.Use(middleware.VTunClientWatcherProvider(s.vtunClientWatcher))
if !s.config.DisableVTun {
r.Use(middleware.VTunClientWatcherProvider(s.vtunClientWatcher))
}
r.Use(middleware.WireguardManagerProvider(s.wireguardManager))
r.Use(middleware.NetworkStats(s.stats))
r.Use(middleware.PaginatedDatabaseProvider(s.db, middleware.PaginationConfig{}))
Expand Down

0 comments on commit 3eb0b2f

Please sign in to comment.