diff --git a/internal/config/config.go b/internal/config/config.go index fdf3ee97..a7d7bafb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/NHAS/wag/internal/acls" + "github.com/NHAS/wag/internal/data/validators" "github.com/NHAS/wag/internal/routetypes" "github.com/NHAS/wag/pkg/control" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -242,7 +243,7 @@ func load(path string) (c Config, err error) { *c.NAT = true } - err = validExternalAddresses(c.ExternalAddress) + err = validators.ValidExternalAddresses(c.ExternalAddress) if err != nil { return c, err } @@ -365,26 +366,6 @@ func validateDns(input []string) (newDnsEntries []string, err error) { return } -func validExternalAddresses(ExternalAddress string) error { - if len(ExternalAddress) == 0 { - return errors.New("invalid ExternalAddress is empty") - } - - if net.ParseIP(ExternalAddress) == nil { - - addresses, err := net.LookupIP(ExternalAddress) - if err != nil { - return errors.New("invalid ExternalAddress: " + ExternalAddress + " unable to lookup as domain") - } - - if len(addresses) == 0 { - return errors.New("invalid ExternalAddress: " + ExternalAddress + " not IPv4 or IPv6 external addresses found") - } - } - - return nil -} - func Load(path string) error { var err error diff --git a/internal/data/config.go b/internal/data/config.go index 45f683da..f2547f04 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" + "github.com/NHAS/wag/internal/data/validators" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -285,6 +286,11 @@ func GetHelpMail() string { } func SetExternalAddress(externalAddress string) error { + + if err := validators.ValidExternalAddresses(externalAddress); err != nil { + return err + } + data, _ := json.Marshal(externalAddress) _, err := etcd.Put(context.Background(), externalAddressKey, string(data)) return err diff --git a/internal/data/validators/config.go b/internal/data/validators/config.go new file mode 100644 index 00000000..b51a8d06 --- /dev/null +++ b/internal/data/validators/config.go @@ -0,0 +1,31 @@ +package validators + +import ( + "errors" + "net" +) + +func ValidExternalAddresses(ExternalAddress string) error { + if len(ExternalAddress) == 0 { + return errors.New("invalid ExternalAddress is empty") + } + + host, _, err := net.SplitHostPort(ExternalAddress) + if err == nil { + // If the external address has a port, split it off and use that as the external address to check + ExternalAddress = host + } + + if net.ParseIP(ExternalAddress) == nil { + + addresses, err := net.LookupIP(ExternalAddress) + if err != nil { + return errors.New("invalid ExternalAddress: " + ExternalAddress + " unable to lookup as domain") + } + + if len(addresses) == 0 { + return errors.New("invalid ExternalAddress: " + ExternalAddress + " not IPv4 or IPv6 external addresses found") + } + } + return nil +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 1b2ed4f3..04ec1862 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -15,6 +15,7 @@ func GetIP(addr string) string { return addr[:i] } } + return addr } diff --git a/internal/webserver/web.go b/internal/webserver/web.go index 8ce4d9fe..fc48602a 100644 --- a/internal/webserver/web.go +++ b/internal/webserver/web.go @@ -511,7 +511,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { acl := data.GetEffectiveAcl(username) - wgPublicKey, _, err := router.ServerDetails() + wgPublicKey, wgPort, err := router.ServerDetails() if err != nil { log.Println(username, remoteAddr, "unable access wireguard device: ", err) http.Error(w, "Server Error", 500) @@ -558,13 +558,21 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { ClientPresharedKey: presharedKey, } - wireguardInterface.ServerAddress, err = data.GetExternalAddress() + externalAddress, err := data.GetExternalAddress() if err != nil { log.Println(username, remoteAddr, "unable to get server external address from datastore: ", err) http.Error(w, "Server Error", 500) return } + // If the external address defined in the config has a port, use that, otherwise defaultly add the same port as the wireguard device + _, _, err = net.SplitHostPort(externalAddress) + if err != nil { + externalAddress = fmt.Sprintf("%s:%d", externalAddress, wgPort) + } + + wireguardInterface.ServerAddress = externalAddress + if r.URL.Query().Get("type") == "mobile" { w.Header().Set("Content-Type", "text/html; charset=UTF-8")