Skip to content

Commit

Permalink
Cascade errors back up to main(), rather than exiting deep in the stack.
Browse files Browse the repository at this point in the history
  • Loading branch information
xxxserxxx committed Aug 11, 2022
1 parent da76ddb commit 9bf5693
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 32 deletions.
19 changes: 14 additions & 5 deletions cmd/cli/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
)

// Add prompts for the required information and creates a new peer
func Add(hostname, owner, description string, confirm bool) {
func Add(hostname, owner, description string, confirm bool) error {
// TODO accept existing pubkey
config, err := LoadConfigFile()
check(err, "failed to load configuration file")
if err != nil {
return wrapError(err, "failed to load configuration file")
}
server := GetServer(config)

if owner == "" {
Expand All @@ -31,7 +33,9 @@ func Add(hostname, owner, description string, confirm bool) {
fmt.Fprintln(os.Stderr)

peer, err := lib.NewPeer(server, owner, hostname, description)
check(err, "failed to get new peer")
if err != nil {
return wrapError(err, "failed to get new peer")
}

// TODO Some kind of recovery here would be nice, to avoid
// leaving things in a potential broken state
Expand All @@ -41,12 +45,17 @@ func Add(hostname, owner, description string, confirm bool) {
peerType := viper.GetString("output")

peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server)
check(err, "failed to get peer configuration")
if err != nil {
return wrapError(err, "failed to get peer configuration")
}
os.Stdout.Write(peerConfigBytes.Bytes())

config.MustSave()

server = GetServer(config)
err = server.ConfigureDevice()
check(err, "failed to configure device")
if err != nil {
return wrapError(err, "failed to configure device")
}
return nil
}
21 changes: 15 additions & 6 deletions cmd/cli/regenerate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/spf13/viper"
)

func Regenerate(hostname string, confirm bool) {
func Regenerate(hostname string, confirm bool) error {
config := MustLoadConfigFile()
server := GetServer(config)

Expand All @@ -21,22 +21,30 @@ func Regenerate(hostname string, confirm bool) {
for _, peer := range server.Peers {
if peer.Hostname == hostname {
privateKey, err := lib.GenerateJSONPrivateKey()
check(err, "failed to generate private key")
if err != nil {
return wrapError(err, "failed to generate private key")
}

preshareKey, err := lib.GenerateJSONKey()
check(err, "failed to generate preshared key")
if err != nil {
return wrapError(err, "failed to generate preshared key")
}

peer.PrivateKey = privateKey
peer.PublicKey = privateKey.PublicKey()
peer.PresharedKey = preshareKey

err = config.RemovePeer(hostname)
check(err, "failed to regenerate peer")
if err != nil {
return wrapError(err, "failed to regenerate peer")
}

peerType := viper.GetString("output")

peerConfigBytes, err := lib.AsciiPeerConfig(peer, peerType, *server)
check(err, "failed to get peer configuration")
if err != nil {
return wrapError(err, "failed to get peer configuration")
}
os.Stdout.Write(peerConfigBytes.Bytes())
found = true
config.MustAddPeer(peer)
Expand All @@ -46,11 +54,12 @@ func Regenerate(hostname string, confirm bool) {
}

if !found {
ExitFail(fmt.Sprintf("unknown hostname: %s", hostname))
return fmt.Errorf("unknown hostname: %s", hostname)
}

// Get a new server configuration so we can update the wg interface with the new peer details
server = GetServer(config)
config.MustSave()
server.ConfigureDevice()
return nil
}
11 changes: 8 additions & 3 deletions cmd/cli/remove.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package cli

import "fmt"

func Remove(hostname string, confirm bool) {
func Remove(hostname string, confirm bool) error {
conf := MustLoadConfigFile()

err := conf.RemovePeer(hostname)
check(err, "failed to update config")
if err != nil {
return wrapError(err, "failed to update config")
}

if !confirm {
ConfirmOrAbort("Do you really want to remove %s?", hostname)
Expand All @@ -16,5 +18,8 @@ func Remove(hostname string, confirm bool) {
server := GetServer(conf)

err = server.ConfigureDevice()
check(err, fmt.Sprintf("failed to sync server config to wg interface: %s", server.InterfaceName))
if err != nil {
return wrapError(err, fmt.Sprintf("failed to sync server config to wg interface: %s", server.InterfaceName))
}
return nil
}
6 changes: 4 additions & 2 deletions cmd/cli/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"os"
Expand Down Expand Up @@ -69,7 +70,7 @@ type PeerReport struct {
TransmitBytesSI string
}

func GenerateReport() {
func GenerateReport() error {
conf := MustLoadConfigFile()

wg, err := wgctrl.New()
Expand All @@ -79,12 +80,13 @@ func GenerateReport() {
dev, err := wg.Device(conf.InterfaceName)

if err != nil {
ExitFail("Could not retrieve device '%s' (%v)", conf.InterfaceName, err)
return wrapError(err, fmt.Sprintf("Could not retrieve device '%s'", conf.InterfaceName))
}

oldReport := MustLoadDsnetReport()
report := GetReport(dev, conf, oldReport)
report.MustSave()
return nil
}

func GetReport(dev *wgtypes.Device, conf *DsnetConfig, oldReport *DsnetReport) DsnetReport {
Expand Down
11 changes: 8 additions & 3 deletions cmd/cli/sync.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package cli

func Sync() {
func Sync() error {
// TODO check device settings first
conf, err := LoadConfigFile()
check(err, "failed to load configuration file")
if err != nil {
return wrapError(err, "failed to load configuration file")
}
server := GetServer(conf)
err = server.ConfigureDevice()
check(err, "failed to sync device configuration")
if err != nil {
return wrapError(err, "failed to sync device configuration")
}
return nil
}
4 changes: 4 additions & 0 deletions cmd/cli/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ func ExitFail(format string, a ...interface{}) {
os.Exit(1)
}

func wrapError(err error, s string) error {
return fmt.Errorf("\033[31m%s - %s\033[0m\n", err, s)
}

func MustPromptString(prompt string, required bool) string {
reader := bufio.NewReader(os.Stdin)
var text string
Expand Down
23 changes: 10 additions & 13 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ var (
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
cli.Add(args[0], owner, description, confirm)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.Add(args[0], owner, description, confirm)
},
}

Expand All @@ -90,24 +90,24 @@ var (
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
cli.Regenerate(args[0], confirm)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.Regenerate(args[0], confirm)
},
}

syncCmd = &cobra.Command{
Use: "sync",
Short: fmt.Sprintf("Update wireguard configuration from %s after validating", viper.GetString("config_file")),
Run: func(cmd *cobra.Command, args []string) {
cli.Sync()
RunE: func(cmd *cobra.Command, args []string) error {
return cli.Sync()
},
}

reportCmd = &cobra.Command{
Use: "report",
Short: fmt.Sprintf("Generate a JSON status report to the location configured in %s.", viper.GetString("config_file")),
Run: func(cmd *cobra.Command, args []string) {
cli.GenerateReport()
RunE: func(cmd *cobra.Command, args []string) error {
return cli.GenerateReport()
},
}

Expand All @@ -122,8 +122,8 @@ var (

return nil
},
Run: func(cmd *cobra.Command, args []string) {
cli.Remove(args[0], confirm)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.Remove(args[0], confirm)
},
}

Expand Down Expand Up @@ -181,8 +181,5 @@ func main() {
if err := rootCmd.Execute(); err != nil {
cli.ExitFail(err.Error())
}
if error_encountered {
os.Exit(1)
}
os.Exit(0)
}

0 comments on commit 9bf5693

Please sign in to comment.