diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..13566b81 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..1d3a2a3d --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..35eb1ddf --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/wag.iml b/.idea/wag.iml new file mode 100644 index 00000000..5e764c4f --- /dev/null +++ b/.idea/wag.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/commands/cleanup.go b/commands/cleanup.go index 81c02b83..70541830 100644 --- a/commands/cleanup.go +++ b/commands/cleanup.go @@ -3,13 +3,12 @@ package commands import ( "flag" "fmt" - "log" - "os" - "os/exec" - "github.com/NHAS/wag/internal/config" + "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/router" "github.com/NHAS/wag/pkg/control/server" + "log" + "os" ) type cleanup struct { @@ -61,11 +60,12 @@ func (g *cleanup) Run() error { if result != "0" && result != "3" { log.Println("Cleaning up") + router.TearDown(true) server.TearDown() - exec.Command("/usr/bin/wg-quick", "save", config.Values.Wireguard.DevName).Run() + data.TearDown() - return exec.Command("/usr/bin/wg-quick", "down", config.Values.Wireguard.DevName).Run() + return nil } diff --git a/commands/start.go b/commands/start.go index 8ba72127..fe4e24e8 100644 --- a/commands/start.go +++ b/commands/start.go @@ -105,7 +105,6 @@ func teardown(force bool) { ui.Teardown() webserver.Teardown() - } func clusterState(noIptables bool, errorChan chan<- error) func(string) { @@ -126,11 +125,14 @@ func clusterState(noIptables bool, errorChan chan<- error) func(string) { switch stateText { case "dead": if !wasDead { - log.Println("Tearing down node") - - teardown(false) - log.Println("Tear down complete") + if !config.Values.Clustering.Witness { + log.Println("Tearing down node") + teardown(false) + log.Println("Tear down complete") + } else { + log.Println("refusing to tear down witness node (nothing to tear down)") + } // Only teardown if we were at one point alive wasDead = true diff --git a/internal/config/config.go b/internal/config/config.go index a7d7bafb..7073e870 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -390,7 +390,7 @@ func parseAddress(address string) ([]string, error) { return nil, fmt.Errorf("no addresses for %s", address) } - output := []string{} + var output []string addedSomething := false for _, addr := range addresses { if addr.To4() != nil { diff --git a/internal/data/clustering.go b/internal/data/clustering.go index f476cf3a..2545b7a4 100644 --- a/internal/data/clustering.go +++ b/internal/data/clustering.go @@ -33,8 +33,8 @@ type NodeControlRequest struct { Action string } -func GetServerID() string { - return etcdServer.Server.ID().String() +func GetServerID() types.ID { + return etcdServer.Server.ID() } func GetLeader() types.ID { @@ -49,11 +49,7 @@ func IsLearner() bool { return etcdServer.Server.IsLearner() } -func IsLeader() bool { - return etcdServer.Server.Leader() == etcdServer.Server.ID() -} - -// Called on a leader node, to transfer ownership to another node (demoted) +// StepDown when called on a leader node, to transfer ownership to another node (demoted) func StepDown() error { return etcdServer.Server.TransferLeadership() } @@ -73,7 +69,7 @@ func GetLastPing(idHex string) (time.Time, error) { return time.Time{}, errors.New("id is not part of cluster") } - lastPing, err := etcd.Get(context.Background(), path.Join(NodeEvents, idHex, "ping")) + lastPing, err := etcd.Get(context.Background(), path.Join(NodeInfo, idHex, "ping")) if err != nil { return time.Time{}, err } @@ -94,18 +90,52 @@ func GetLastPing(idHex string) (time.Time, error) { return t, nil } -func SetDrained(idHex string, on bool) error { +func SetWitness(on bool) error { + if on { + _, err := etcd.Put(context.Background(), path.Join(NodeInfo, GetServerID().String(), "witness"), fmt.Sprintf("%t", on)) + return err + } + + _, err := etcd.Delete(context.Background(), path.Join(NodeInfo, GetServerID().String(), "witness")) + return err +} + +func IsWitness(idHex string) (bool, error) { _, err := strconv.ParseUint(idHex, 16, 64) + if err != nil { + return false, fmt.Errorf("bad member ID arg (%v), expecting ID in Hex", err) + } + + isDrained, err := etcd.Get(context.Background(), path.Join(NodeInfo, idHex, "witness")) + if err != nil { + return false, err + } + + return isDrained.Count != 0, nil +} + +func SetDrained(idHex string, on bool) error { + + isWitness, err := IsWitness(idHex) + if err != nil { + return err + } + + if isWitness { + return errors.New("cannot set drained on witness node, this node is not serving clients") + } + + _, err = strconv.ParseUint(idHex, 16, 64) if err != nil { return err } if on { - _, err = etcd.Put(context.Background(), path.Join(NodeEvents, idHex, "drain"), fmt.Sprintf("%t", on)) + _, err = etcd.Put(context.Background(), path.Join(NodeInfo, idHex, "drain"), fmt.Sprintf("%t", on)) return err } - _, err = etcd.Delete(context.Background(), path.Join(NodeEvents, idHex, "drain")) + _, err = etcd.Delete(context.Background(), path.Join(NodeInfo, idHex, "drain")) return err } @@ -115,7 +145,7 @@ func IsDrained(idHex string) (bool, error) { return false, fmt.Errorf("bad member ID arg (%v), expecting ID in Hex", err) } - isDrained, err := etcd.Get(context.Background(), path.Join(NodeEvents, idHex, "drain")) + isDrained, err := etcd.Get(context.Background(), path.Join(NodeInfo, idHex, "drain")) if err != nil { return false, err } @@ -234,7 +264,7 @@ func RemoveMember(idHex string) error { } // Clear any node metadata - _, err = etcd.Delete(context.Background(), path.Join(NodeEvents, idHex), clientv3.WithPrefix()) + _, err = etcd.Delete(context.Background(), path.Join(NodeInfo, idHex), clientv3.WithPrefix()) if err != nil { return err } diff --git a/internal/data/config.go b/internal/data/config.go index c8bd6276..27a3b4e5 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -8,7 +8,6 @@ import ( "net/url" "strings" - "github.com/NHAS/wag/internal/data/validators" "github.com/go-playground/validator/v10" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -92,16 +91,6 @@ func getInt(key string) (ret int, err error) { return ret, nil } -func SetPAM(details PAM) error { - d, err := json.Marshal(details) - if err != nil { - return err - } - - _, err = etcd.Put(context.Background(), PamDetailsKey, string(d)) - return err -} - func GetPAM() (details PAM, err error) { response, err := etcd.Get(context.Background(), OidcDetailsKey) @@ -117,16 +106,6 @@ func GetPAM() (details PAM, err error) { return } -func SetOidc(details OIDC) error { - d, err := json.Marshal(details) - if err != nil { - return err - } - - _, err = etcd.Put(context.Background(), OidcDetailsKey, string(d)) - return err -} - func GetOidc() (details OIDC, err error) { response, err := etcd.Get(context.Background(), OidcDetailsKey) @@ -175,13 +154,6 @@ func GetWebauthn() (wba Webauthn, err error) { return } -func SetWireguardConfigName(wgConfig string) error { - data, _ := json.Marshal(wgConfig) - - _, err := etcd.Put(context.Background(), defaultWGFileNameKey, string(data)) - return err -} - func GetWireguardConfigName() string { k, err := getString(defaultWGFileNameKey) if err != nil { @@ -232,14 +204,6 @@ func GetAuthenicationMethods() (result []string, err error) { return } -func SetCheckUpdates(doChecks bool) error { - - data, _ := json.Marshal(doChecks) - - _, err := etcd.Put(context.Background(), checkUpdatesKey, string(data)) - return err -} - func ShouldCheckUpdates() (bool, error) { resp, err := etcd.Get(context.Background(), checkUpdatesKey) @@ -261,12 +225,6 @@ func ShouldCheckUpdates() (bool, error) { return ret, nil } -func SetDomain(domain string) error { - data, _ := json.Marshal(domain) - _, err := etcd.Put(context.Background(), DomainKey, string(data)) - return err -} - func GetDomain() (string, error) { return getString(DomainKey) } @@ -297,17 +255,6 @@ func GetHelpMail() string { return mail } -func SetExternalAddress(externalAddress string) error { - - if err := validators.ValidExternalAddresses(externalAddress); err != nil { - return err - } - - data, _ := json.Marshal(externalAddress) - _, err := etcd.Put(context.Background(), externalAddressKey, string(data)) - return err -} - func GetExternalAddress() (string, error) { return getString(externalAddressKey) } @@ -625,16 +572,6 @@ func GetSessionInactivityTimeoutMinutes() (int, error) { return inactivityTimeout, nil } -func SetLockout(accountLockout int) error { - if accountLockout < 1 { - return errors.New("cannot set lockout to be below 1 as all accounts would be locked out") - } - - data, _ := json.Marshal(accountLockout) - _, err := etcd.Put(context.Background(), LockoutKey, string(data)) - return err -} - // Get account lockout threshold setting func GetLockout() (int, error) { lockout, err := getInt(LockoutKey) diff --git a/internal/data/devices.go b/internal/data/devices.go index d86cec91..c585d188 100644 --- a/internal/data/devices.go +++ b/internal/data/devices.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "go.etcd.io/etcd/client/pkg/v3/types" "net" "time" @@ -23,6 +24,8 @@ type Device struct { Attempts int Active bool Authorised time.Time + + AssociatedNode types.ID } func (d Device) String() string { @@ -32,10 +35,13 @@ func (d Device) String() string { authorised = d.Authorised.Format(time.DateTime) } - return fmt.Sprintf("device[%s:%s][active: %t, attempts: %d, authorised: %s]", d.Username, d.Address, d.Active, d.Attempts, authorised) + return fmt.Sprintf("device[%s:%s:%s][active: %t, attempts: %d, authorised: %s]", d.Username, d.Address, d.AssociatedNode, d.Active, d.Attempts, authorised) } -func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error { +// UpdateDeviceConnectionDetails updates the endpoint we are receiving packets from and the associated cluster node +// I.e if data is coming in to node 3, all other nodes know that the session is only valid while connecting to node 3 +// this stops a race condition where an attacker uses a wireguard profile, but gets load balanced to another node member +func UpdateDeviceConnectionDetails(address string, endpoint *net.UDPAddr) error { realKey, err := etcd.Get(context.Background(), "deviceref-"+address) if err != nil { @@ -58,6 +64,7 @@ func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error { } device.Endpoint = endpoint + device.AssociatedNode = GetServerID() b, _ := json.Marshal(device) @@ -108,6 +115,7 @@ func AuthoriseDevice(username, address string) error { return "", errors.New("account is locked") } + device.AssociatedNode = GetServerID() device.Authorised = time.Now() device.Attempts = 0 @@ -187,6 +195,27 @@ func GetAllDevices() (devices []Device, err error) { return devices, nil } +func GetAllDevicesAsMap() (devices map[string]Device, err error) { + + devices = make(map[string]Device) + response, err := etcd.Get(context.Background(), "devices-", clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortDescend)) + if err != nil { + return nil, err + } + + for _, res := range response.Kvs { + var device Device + err := json.Unmarshal(res.Value, &device) + if err != nil { + return nil, err + } + + devices[device.Address] = device + } + + return devices, nil +} + func AddDevice(username, publickey string) (Device, error) { preshared_key, err := wgtypes.GenerateKey() @@ -292,7 +321,7 @@ func DeleteDevices(username string) error { return err } - ops := []clientv3.Op{} + var ops []clientv3.Op for _, reference := range deleted.PrevKvs { var d Device @@ -310,7 +339,7 @@ func DeleteDevices(username string) error { func UpdateDevicePublicKey(username, address string, publicKey wgtypes.Key) error { - beforeUpadte, err := GetDeviceByAddress(address) + beforeUpdate, err := GetDeviceByAddress(address) if err != nil { return err } @@ -337,7 +366,7 @@ func UpdateDevicePublicKey(username, address string, publicKey wgtypes.Key) erro return err } - _, err = etcd.Delete(context.Background(), "devicesref-"+beforeUpadte.Publickey) + _, err = etcd.Delete(context.Background(), "devicesref-"+beforeUpdate.Publickey) return err } diff --git a/internal/data/events.go b/internal/data/events.go index 4d158d89..895cdb88 100644 --- a/internal/data/events.go +++ b/internal/data/events.go @@ -43,7 +43,7 @@ const ( GroupsPrefix = "wag-groups-" ConfigPrefix = "wag-config-" AuthenticationPrefix = "wag-config-authentication-" - NodeEvents = "wag/node/" + NodeInfo = "wag/node/" NodeErrors = "wag/node/errors" ) @@ -147,18 +147,6 @@ func RegisterEventListener[T any](path string, isPrefix bool, f func(key string, return key, nil } -func DeregisterEventListener(key string) { - lck.Lock() - defer lck.Unlock() - - if cancel, ok := contextMaps[key]; ok { - if cancel != nil { - cancel() - } - delete(contextMaps, key) - } -} - func RegisterClusterHealthListener(f func(status string)) (string, error) { clusterHealthLck.Lock() defer clusterHealthLck.Unlock() @@ -173,13 +161,6 @@ func RegisterClusterHealthListener(f func(status string)) (string, error) { return key, nil } -func DeregisterClusterHealthListener(key string) { - clusterHealthLck.Lock() - defer clusterHealthLck.Unlock() - - delete(clusterHealthListeners, key) -} - func notifyClusterHealthListeners(event string) { clusterHealthLck.RLock() defer clusterHealthLck.RUnlock() @@ -227,7 +208,7 @@ func checkClusterHealth() { func testCluster() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - _, err := etcd.Put(ctx, path.Join(NodeEvents, GetServerID(), "ping"), time.Now().Format(time.RFC1123Z)) + _, err := etcd.Put(ctx, path.Join(NodeInfo, GetServerID().String(), "ping"), time.Now().Format(time.RFC1123Z)) cancel() if err != nil { log.Println("unable to write liveness value") @@ -257,7 +238,7 @@ type EventError struct { func RaiseError(raisedError error, value []byte) (err error) { ee := EventError{ - NodeID: GetServerID(), + NodeID: GetServerID().String(), FailedEventData: string(value), Error: raisedError.Error(), Time: time.Now(), diff --git a/internal/data/init.go b/internal/data/init.go index 3cfaa12f..2a325b1e 100644 --- a/internal/data/init.go +++ b/internal/data/init.go @@ -24,10 +24,6 @@ import ( "go.etcd.io/etcd/server/v3/embed" ) -// TODO Most of the methods in this package need to change to prevent race conditions from breaking the cluster -// The way we're going to do this is each get method will return the etcd key revision. If a user tries to make a change using an older revision then that will be an error -// Or a prompt on the admin ui - var ( etcd *clientv3.Client etcdServer *embed.Etcd @@ -197,6 +193,12 @@ func Load(path, joinToken string, testing bool) error { } + if config.Values.Clustering.Witness { + + } else { + + } + go checkClusterHealth() return nil @@ -457,7 +459,7 @@ func TearDown() { func doSafeUpdate(ctx context.Context, key string, create bool, mutateFunc func(*clientv3.GetResponse) (value string, err error)) error { //https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go#L382 - opts := []clientv3.OpOption{} + var opts []clientv3.OpOption if mutateFunc == nil { return errors.New("no mutate function set in safe update") diff --git a/internal/data/user.go b/internal/data/user.go index 5c182137..ca6c247b 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -24,7 +24,7 @@ func (um *UserModel) GetID() [20]byte { return sha1.Sum([]byte(um.Username)) } -// Make sure that the attempts is always incremented first to stop race condition attacks +// IncrementAuthenticationAttempt Make sure that the attempts is always incremented first to stop race condition attacks func IncrementAuthenticationAttempt(username, device string) error { return doSafeUpdate(context.Background(), deviceKey(username, device), false, func(gr *clientv3.GetResponse) (value string, err error) { diff --git a/internal/router/bpf.go b/internal/router/bpf.go index e6053428..1cf63f7c 100644 --- a/internal/router/bpf.go +++ b/internal/router/bpf.go @@ -28,7 +28,6 @@ import ( //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf xdp.c -- -I headers const ( - ebpfFS = "/sys/fs/bpf" CLOCK_BOOTTIME = uint32(7) ) @@ -102,6 +101,11 @@ func loadXDP() error { return fmt.Errorf("loading objects: %s", err) } + err = xdpObjects.NodeId.Put(uint32(0), uint64(data.GetServerID())) + if err != nil { + return fmt.Errorf("could not set node id: %s", err) + } + sessionInactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes() if err != nil { return err @@ -164,7 +168,7 @@ func setupXDP(users []data.UserModel, knownDevices []data.Device) error { for _, device := range knownDevices { - err := xdpAddDevice(device.Username, device.Address) + err := xdpAddDevice(device.Username, device.Address, uint64(device.AssociatedNode)) if err != nil { return errors.New("xdp setup add device to user: " + err.Error()) } @@ -261,7 +265,7 @@ func xdpRemoveDevice(address string) error { return finalError } -func xdpAddDevice(username, address string) error { +func xdpAddDevice(username, address string, associatedNode uint64) error { ip := net.ParseIP(address) if ip == nil { @@ -279,6 +283,7 @@ func xdpAddDevice(username, address string) error { deviceStruct.lastPacketTime = 0 deviceStruct.sessionExpiry = 0 deviceStruct.user_id = sha1.Sum([]byte(username)) + deviceStruct.associatedNode = associatedNode if err := xdpUserExists(deviceStruct.user_id); err != nil { return err @@ -417,7 +422,7 @@ func clearPolicyMap(toClear *ebpf.Map) error { } err = toClear.Delete(lastKey) - if err != nil && err != ebpf.ErrKeyNotExist { + if err != nil && errors.Is(err, ebpf.ErrKeyNotExist) { return err } @@ -594,7 +599,7 @@ func RefreshUserAcls(username string) error { } // SetAuthroized correctly sets the timestamps for a device with internal IP address as internalAddress -func SetAuthorized(internalAddress, username string) error { +func SetAuthorized(internalAddress, username string, node uint64) error { if net.ParseIP(internalAddress).To4() == nil { return errors.New("internalAddress could not be parsed as an IPv4 address") @@ -605,6 +610,7 @@ func SetAuthorized(internalAddress, username string) error { var deviceStruct fwentry deviceStruct.lastPacketTime = GetTimeStamp() + deviceStruct.associatedNode = node maxSession, err := data.GetSessionLifetimeMinutes() if err != nil { @@ -621,6 +627,36 @@ func SetAuthorized(internalAddress, username string) error { return xdpObjects.Devices.Update(net.ParseIP(internalAddress).To4(), deviceStruct.Bytes(), ebpf.UpdateExist) } +func UpdateNodeAssociation(device data.Device) error { + lock.Lock() + defer lock.Unlock() + + // If the peer roams away from us, unset the endpoint in wireguard to make sure the peer watcher will absolutely register a change if they roam back + if device.AssociatedNode != data.GetServerID() { + err := setPeerEndpoint(device, nil) + if err != nil { + return err + } + } + + ip := net.ParseIP(device.Address) + + deviceBytes, err := xdpObjects.Devices.LookupBytes(ip.To4()) + if err != nil { + return err + } + + var deviceStruct fwentry + err = deviceStruct.Unpack(deviceBytes) + if err != nil { + return err + } + + deviceStruct.associatedNode = uint64(device.AssociatedNode) + + return xdpObjects.Devices.Update(ip.To4(), deviceStruct.Bytes(), ebpf.UpdateExist) +} + func Deauthenticate(address string) error { lock.Lock() @@ -681,6 +717,7 @@ type fwDevice struct { Expiry uint64 IP string Authorized bool + AssociatedNode string } func GetRoutes(username string) ([]string, error) { @@ -796,7 +833,13 @@ func GetRules() (map[string]FirewallRules, error) { } fwRule := result[res] - fwRule.Devices = append(fwRule.Devices, fwDevice{IP: net.IP(ipBytes).String(), Authorized: isAuthed(net.IP(ipBytes).String()), Expiry: deviceStruct.sessionExpiry, LastPacketTimestamp: deviceStruct.lastPacketTime}) + fwRule.Devices = append(fwRule.Devices, fwDevice{ + IP: net.IP(ipBytes).String(), + Authorized: isAuthed(net.IP(ipBytes).String()), + Expiry: deviceStruct.sessionExpiry, + LastPacketTimestamp: deviceStruct.lastPacketTime, + AssociatedNode: fmt.Sprintf("%x (%d)", deviceStruct.associatedNode, deviceStruct.associatedNode), + }) if err := xdpObjects.AccountLocked.Lookup(deviceStruct.user_id, &fwRule.AccountLocked); err != nil { log.Println("[ERROR] User ID was not properly in firewall map: ", hex.EncodeToString(deviceStruct.user_id[:]), " err: ", err) diff --git a/internal/router/bpf_bpfeb.go b/internal/router/bpf_bpfeb.go index 81946d39..e3c37716 100644 --- a/internal/router/bpf_bpfeb.go +++ b/internal/router/bpf_bpfeb.go @@ -63,6 +63,7 @@ type bpfMapSpecs struct { AccountLocked *ebpf.MapSpec `ebpf:"account_locked"` Devices *ebpf.MapSpec `ebpf:"devices"` InactivityTimeoutMinutes *ebpf.MapSpec `ebpf:"inactivity_timeout_minutes"` + NodeId *ebpf.MapSpec `ebpf:"node_Id"` PoliciesTable *ebpf.MapSpec `ebpf:"policies_table"` } @@ -88,6 +89,7 @@ type bpfMaps struct { AccountLocked *ebpf.Map `ebpf:"account_locked"` Devices *ebpf.Map `ebpf:"devices"` InactivityTimeoutMinutes *ebpf.Map `ebpf:"inactivity_timeout_minutes"` + NodeId *ebpf.Map `ebpf:"node_Id"` PoliciesTable *ebpf.Map `ebpf:"policies_table"` } @@ -96,6 +98,7 @@ func (m *bpfMaps) Close() error { m.AccountLocked, m.Devices, m.InactivityTimeoutMinutes, + m.NodeId, m.PoliciesTable, ) } diff --git a/internal/router/bpf_bpfeb.o b/internal/router/bpf_bpfeb.o index d1489286..4a7b25f5 100644 Binary files a/internal/router/bpf_bpfeb.o and b/internal/router/bpf_bpfeb.o differ diff --git a/internal/router/bpf_bpfel.go b/internal/router/bpf_bpfel.go index 083867a1..04f4c1b5 100644 --- a/internal/router/bpf_bpfel.go +++ b/internal/router/bpf_bpfel.go @@ -63,6 +63,7 @@ type bpfMapSpecs struct { AccountLocked *ebpf.MapSpec `ebpf:"account_locked"` Devices *ebpf.MapSpec `ebpf:"devices"` InactivityTimeoutMinutes *ebpf.MapSpec `ebpf:"inactivity_timeout_minutes"` + NodeId *ebpf.MapSpec `ebpf:"node_Id"` PoliciesTable *ebpf.MapSpec `ebpf:"policies_table"` } @@ -88,6 +89,7 @@ type bpfMaps struct { AccountLocked *ebpf.Map `ebpf:"account_locked"` Devices *ebpf.Map `ebpf:"devices"` InactivityTimeoutMinutes *ebpf.Map `ebpf:"inactivity_timeout_minutes"` + NodeId *ebpf.Map `ebpf:"node_Id"` PoliciesTable *ebpf.Map `ebpf:"policies_table"` } @@ -96,6 +98,7 @@ func (m *bpfMaps) Close() error { m.AccountLocked, m.Devices, m.InactivityTimeoutMinutes, + m.NodeId, m.PoliciesTable, ) } diff --git a/internal/router/bpf_bpfel.o b/internal/router/bpf_bpfel.o index 3d253d79..75ce5bcb 100644 Binary files a/internal/router/bpf_bpfel.o and b/internal/router/bpf_bpfel.o differ diff --git a/internal/router/ebpf_test.go b/internal/router/ebpf_test.go index be6c82b8..3274bbf3 100644 --- a/internal/router/ebpf_test.go +++ b/internal/router/ebpf_test.go @@ -63,7 +63,7 @@ func TestBlankPacket(t *testing.T) { func TestAddNewDevices(t *testing.T) { var ipBytes []byte - var deviceBytes = make([]byte, 40) + var deviceBytes = make([]byte, 48) found := map[string]bool{} @@ -212,7 +212,7 @@ func TestRoutePriority(t *testing.T) { func TestBasicAuthorise(t *testing.T) { - err := SetAuthorized(devices["tester"].Address, devices["tester"].Username) + err := SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) if err != nil { t.Fatal(err) } @@ -426,7 +426,7 @@ func TestRoutePreference(t *testing.T) { func TestSlidingWindow(t *testing.T) { - err := SetAuthorized(devices["tester"].Address, devices["tester"].Username) + err := SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) if err != nil { t.Fatal(err) } @@ -529,7 +529,7 @@ func TestSlidingWindow(t *testing.T) { func TestCompositeRules(t *testing.T) { - err := SetAuthorized(devices["tester"].Address, devices["tester"].Username) + err := SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) if err != nil { t.Fatal(err) } @@ -632,7 +632,7 @@ func TestDisabledSlidingWindow(t *testing.T) { t.Fatalf("the inactivity timeout was not set to max uint64, was %d (maxuint64 %d)", timeoutFromMap, uint64(math.MaxUint64)) } - err = SetAuthorized(devices["tester"].Address, devices["tester"].Username) + err = SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) if err != nil { t.Fatal(err) } @@ -688,7 +688,7 @@ func TestDisabledSlidingWindow(t *testing.T) { func TestMaxSessionLifetime(t *testing.T) { - err := SetAuthorized(devices["tester"].Address, devices["tester"].Username) + err := SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) if err != nil { t.Fatal(err) } @@ -753,7 +753,7 @@ func TestDisablingMaxLifetime(t *testing.T) { t.Fatal(err) } - err = SetAuthorized(devices["tester"].Address, devices["tester"].Username) + err = SetAuthorized(devices["tester"].Address, devices["tester"].Username, uint64(data.GetServerID())) if err != nil { t.Fatal(err) } @@ -956,7 +956,6 @@ func TestAgnosticRuleOrdering(t *testing.T) { for _, user := range devices { acl := data.GetEffectiveAcl(user.Username) - log.Println(user, acl.Allow) rules, err := routetypes.ParseRules(nil, acl.Allow, nil) if err != nil { t.Fatal(err) @@ -1159,7 +1158,7 @@ func addDevices() error { return err } - err = xdpAddDevice(device.Username, device.Address) + err = xdpAddDevice(device.Username, device.Address, uint64(data.GetServerID())) if err != nil { return err } diff --git a/internal/router/fwentry.go b/internal/router/fwentry.go index d89a6827..03de54e3 100644 --- a/internal/router/fwentry.go +++ b/internal/router/fwentry.go @@ -15,15 +15,17 @@ type fwentry struct { user_id [20]byte pad uint32 + + associatedNode uint64 } func (d fwentry) Size() int { - return 40 // 8 + 8 + 20 + 4 + return 48 // 8 + 8 + 20 + 4 + 8 } func (d fwentry) Bytes() []byte { - output := make([]byte, 40) + output := make([]byte, 48) binary.LittleEndian.PutUint64(output[0:8], d.sessionExpiry) binary.LittleEndian.PutUint64(output[8:16], d.lastPacketTime) @@ -31,12 +33,13 @@ func (d fwentry) Bytes() []byte { copy(output[16:36], d.user_id[:]) binary.LittleEndian.PutUint32(output[36:], d.pad) + binary.LittleEndian.PutUint64(output[40:], d.associatedNode) return output } func (d *fwentry) Unpack(b []byte) error { - if len(b) != 40 { + if len(b) != 48 { return errors.New("firewall entry is too short") } @@ -46,6 +49,7 @@ func (d *fwentry) Unpack(b []byte) error { copy(d.user_id[:], b[16:36]) d.pad = binary.LittleEndian.Uint32(b[36:]) + d.associatedNode = binary.LittleEndian.Uint64(b[40:]) return nil } diff --git a/internal/router/init.go b/internal/router/init.go index 17382350..30088ffc 100644 --- a/internal/router/init.go +++ b/internal/router/init.go @@ -53,18 +53,7 @@ func Setup(errorChan chan<- error, iptables bool) (err error) { handleEvents(errorChan) go func() { - startup := true - cache := map[string]string{} - d, err := data.GetAllDevices() - if err != nil { - errorChan <- err - return - } - - for _, device := range d { - cache[device.Address] = device.Endpoint.String() - } - + ourPeerAddresses := make(map[string]string) for { select { @@ -77,6 +66,11 @@ func Setup(errorChan chan<- error, iptables bool) (err error) { return } + devices, err := data.GetAllDevicesAsMap() + if err != nil { + errorChan <- fmt.Errorf("endpoint watcher: failed to retrieve devices from etcd: %s", err) + return + } for _, p := range dev.Peers { if len(p.AllowedIPs) != 1 { @@ -84,37 +78,42 @@ func Setup(errorChan chan<- error, iptables bool) (err error) { continue } - ip := p.AllowedIPs[0].IP.String() + device, ok := devices[p.AllowedIPs[0].IP.String()] + if !ok { + log.Println("found unknown device,", p.AllowedIPs[0].IP.String()) + continue + } + + // If the peer endpoint has become empty (due to peer roaming) or if we dont have a record of it, set the map + if _, ok := ourPeerAddresses[device.Address]; !ok || p.Endpoint == nil { + ourPeerAddresses[device.Address] = p.Endpoint.String() + } - if cache[ip] != p.Endpoint.String() { - cache[ip] = p.Endpoint.String() + // If the peer address has changed, but is not empty (empty indicates the peer has changed it node association away from this node) + if ourPeerAddresses[device.Address] != p.Endpoint.String() && ourPeerAddresses[device.Address] != "" { + ourPeerAddresses[device.Address] = p.Endpoint.String() - d, err := data.GetDeviceByAddress(ip) - if err != nil { - log.Println("unable to get previous device endpoint for", ip, "err:", err) - if err := Deauthenticate(ip); err != nil { - log.Println(ip, "unable to remove forwards for device:", err) + // If we register an endpoint change on our real world device, and the Endpoint is not the same as what the cluster knows + // i.e the peer has either roamed and its egress has changed, or it's an attacker using a stolen wireguard profile + // Deauthenticate it + if device.Endpoint.String() != p.Endpoint.String() { + log.Printf("%s:%s endpoint changed %s -> %s", device.Address, device.Username, device.Endpoint.String(), p.Endpoint.String()) + + err = data.DeauthenticateDevice(device.Address) + if err != nil { + log.Printf("failed to deauth device (%s:%s) endpoint: %s", device.Address, device.Username, err) } - continue } - err = data.UpdateDeviceEndpoint(p.AllowedIPs[0].IP.String(), p.Endpoint) + // Otherwise, just update the node association + err = data.UpdateDeviceConnectionDetails(p.AllowedIPs[0].IP.String(), p.Endpoint) if err != nil { - log.Println(ip, "unable to update device endpoint: ", err) + log.Printf("unable to update device (%s:%s) endpoint: %s", device.Address, device.Username, err) } - //Dont try and remove rules, if we've just started - if !startup { - log.Println(ip, "endpoint changed", d.Endpoint.String(), "->", p.Endpoint.String()) - if err := Deauthenticate(ip); err != nil { - log.Println(ip, "unable to remove forwards for device: ", err) - } - } } } - - startup = false } } diff --git a/internal/router/iptables.go b/internal/router/iptables.go index 4fe7d2ed..53374bce 100644 --- a/internal/router/iptables.go +++ b/internal/router/iptables.go @@ -18,7 +18,7 @@ func setupIptables() error { devName := config.Values.Wireguard.DevName //So. This to the average person will look like we say "Hey server forward anything and everything from the wireguard interface" - //And without the xdp ebpf program it would be, however if you look at xdp.c you can see that we can manipluate maps of addresses for each user + //And without the xdp ebpf program it would be, however if you look at xdp.c you can see that we can manipulate maps of addresses for each user //This then controls whether the packet is dropped, but we still need iptables to do the higher level routing stuffs err = ipt.ChangePolicy("filter", "FORWARD", "DROP") diff --git a/internal/router/statemachine.go b/internal/router/statemachine.go index ff63d635..d7687913 100644 --- a/internal/router/statemachine.go +++ b/internal/router/statemachine.go @@ -11,47 +11,47 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -func handleEvents(erroChan chan<- error) { +func handleEvents(errorChan chan<- error) { _, err := data.RegisterEventListener(data.DevicesPrefix, true, deviceChanges) if err != nil { - erroChan <- err + errorChan <- err return } _, err = data.RegisterEventListener(data.GroupMembershipPrefix, true, membershipChanges) if err != nil { - erroChan <- err + errorChan <- err return } _, err = data.RegisterEventListener(data.UsersPrefix, true, userChanges) if err != nil { - erroChan <- err + errorChan <- err return } _, err = data.RegisterEventListener(data.AclsPrefix, true, aclsChanges) if err != nil { - erroChan <- err + errorChan <- err return } _, err = data.RegisterEventListener(data.GroupsPrefix, true, groupChanges) if err != nil { - erroChan <- err + errorChan <- err return } _, err = data.RegisterEventListener(data.InactivityTimeoutKey, true, inactivityTimeoutChanges) if err != nil { - erroChan <- err + errorChan <- err return } } -func inactivityTimeoutChanges(key string, current, previous int, et data.EventType) error { +func inactivityTimeoutChanges(_ string, current, _ int, et data.EventType) error { switch et { case data.MODIFIED, data.CREATED: @@ -64,7 +64,7 @@ func inactivityTimeoutChanges(key string, current, previous int, et data.EventTy return nil } -func deviceChanges(key string, current, previous data.Device, et data.EventType) error { +func deviceChanges(_ string, current, previous data.Device, et data.EventType) error { switch et { case data.DELETED: @@ -77,7 +77,7 @@ func deviceChanges(key string, current, previous data.Device, et data.EventType) case data.CREATED: key, _ := wgtypes.ParseKey(current.Publickey) - err := AddPeer(key, current.Username, current.Address, current.PresharedKey) + err := AddPeer(key, current.Username, current.Address, current.PresharedKey, uint64(current.AssociatedNode)) if err != nil { return fmt.Errorf("unable to create peer: %s: err: %s", current.Address, err) } @@ -99,7 +99,7 @@ func deviceChanges(key string, current, previous data.Device, et data.EventType) return fmt.Errorf("cannot get lockout: %s", err) } - if (current.Attempts != previous.Attempts && current.Attempts > lockout) || // If the number of authentication attempts on a device has exceeded the max + if current.Attempts > lockout || // If the number of authentication attempts on a device has exceeded the max current.Endpoint.String() != previous.Endpoint.String() || // If the client ip has changed current.Authorised.IsZero() { // If we've explicitly deauthorised a device err := Deauthenticate(current.Address) @@ -110,9 +110,19 @@ func deviceChanges(key string, current, previous data.Device, et data.EventType) } - if current.Authorised != previous.Authorised { - if !current.Authorised.IsZero() && current.Attempts <= lockout { - err := SetAuthorized(current.Address, current.Username) + if current.AssociatedNode != previous.AssociatedNode { + err := UpdateNodeAssociation(current) + if err != nil { + return fmt.Errorf("cannot change device node association %s:%s: %s", current.Address, current.Username, err) + } + + log.Printf("changed device (%s:%s) node association: %s -> %s", current.Address, current.Username, previous.AssociatedNode, current.AssociatedNode) + } + + // If the authorisation state has changed and is not disabled + if current.Authorised != previous.Authorised && !current.Authorised.IsZero() { + if current.Attempts <= lockout && current.AssociatedNode == previous.AssociatedNode { + err := SetAuthorized(current.Address, current.Username, uint64(current.AssociatedNode)) if err != nil { return fmt.Errorf("cannot authorize device %s: %s", current.Address, err) } @@ -127,7 +137,7 @@ func deviceChanges(key string, current, previous data.Device, et data.EventType) return nil } -func membershipChanges(key string, current, previous []string, et data.EventType) error { +func membershipChanges(key string, _, _ []string, et data.EventType) error { username := strings.TrimPrefix(key, data.GroupMembershipPrefix) switch et { @@ -142,11 +152,11 @@ func membershipChanges(key string, current, previous []string, et data.EventType return nil } -func userChanges(key string, current, previous data.UserModel, et data.EventType) error { +func userChanges(_ string, current, previous data.UserModel, et data.EventType) error { switch et { case data.CREATED: - acls := data.GetEffectiveAcl(current.Username) - err := AddUser(current.Username, acls) + newUserAcls := data.GetEffectiveAcl(current.Username) + err := AddUser(current.Username, newUserAcls) if err != nil { log.Printf("cannot create user %s: %s", current.Username, err) return fmt.Errorf("cannot create user %s: %s", current.Username, err) @@ -187,7 +197,8 @@ func userChanges(key string, current, previous data.UserModel, et data.EventType return nil } -func aclsChanges(key string, current, previous acls.Acl, et data.EventType) error { +func aclsChanges(_ string, _, _ acls.Acl, et data.EventType) error { + // TODO refresh the users that the acl applies to as a potential performance improvement switch et { case data.CREATED, data.DELETED, data.MODIFIED: err := RefreshConfiguration() @@ -200,7 +211,7 @@ func aclsChanges(key string, current, previous acls.Acl, et data.EventType) erro return nil } -func groupChanges(key string, current, previous []string, et data.EventType) error { +func groupChanges(_ string, current, _ []string, et data.EventType) error { switch et { case data.CREATED, data.DELETED, data.MODIFIED: diff --git a/internal/router/wireguard.go b/internal/router/wireguard.go index 4e58da70..5a5b5b1e 100644 --- a/internal/router/wireguard.go +++ b/internal/router/wireguard.go @@ -96,10 +96,13 @@ func setupWireguard(devices []data.Device) error { PublicKey: pk, ReplaceAllowedIPs: true, AllowedIPs: []net.IPNet{*network}, - Endpoint: device.Endpoint, PresharedKey: psk, } + if device.AssociatedNode == data.GetServerID() { + pc.Endpoint = device.Endpoint + } + if config.Values.Wireguard.ServerPersistentKeepAlive > 0 { d := time.Duration(config.Values.Wireguard.ServerPersistentKeepAlive) * time.Second pc.PersistentKeepaliveInterval = &d @@ -236,6 +239,30 @@ func ReplacePeer(device data.Device, newPublicKey wgtypes.Key) error { return nil } +func setPeerEndpoint(device data.Device, endpoint *net.UDPAddr) error { + + id, err := wgtypes.ParseKey(device.Publickey) + if err != nil { + return err + } + + var c wgtypes.Config + c.Peers = []wgtypes.PeerConfig{ + { + UpdateOnly: true, + PublicKey: id, + Endpoint: endpoint, + }, + } + + err = ctrl.ConfigureDevice(config.Values.Wireguard.DevName, c) + if err != nil { + return err + } + + return nil +} + func ListPeers() ([]wgtypes.Peer, error) { lock.Lock() @@ -250,7 +277,7 @@ func ListPeers() ([]wgtypes.Peer, error) { } // AddPeer adds the device to wireguard -func AddPeer(public wgtypes.Key, username, addresss, presharedKey string) (err error) { +func AddPeer(public wgtypes.Key, username, addresss, presharedKey string, node uint64) (err error) { lock.Lock() defer lock.Unlock() @@ -275,9 +302,8 @@ func AddPeer(public wgtypes.Key, username, addresss, presharedKey string) (err e }, } - err = xdpAddDevice(username, addresss) + err = xdpAddDevice(username, addresss, node) if err != nil { - return err } @@ -298,21 +324,6 @@ func AddPeer(public wgtypes.Key, username, addresss, presharedKey string) (err e return nil } -func GetPeerRealIp(address string) (string, error) { - dev, err := ctrl.Device(config.Values.Wireguard.DevName) - if err != nil { - return "", err - } - - for _, peer := range dev.Peers { - if len(peer.AllowedIPs) == 1 && peer.AllowedIPs[0].IP.String() == address { - return peer.Endpoint.String(), nil - } - } - - return "", errors.New("not found") -} - func addWg(c *netlink.Conn, name string, address net.IPNet, mtu int) error { infomsg := IfInfomsg{ diff --git a/internal/router/xdp.c b/internal/router/xdp.c index e2e1faf1..06800019 100644 --- a/internal/router/xdp.c +++ b/internal/router/xdp.c @@ -14,34 +14,33 @@ char __license[] SEC("license") = "Dual MIT/GPL"; /* A massive oversimplifcation of what is in this file. - - ┌───────────────────────────────┐ ┌───────────────────────────────────┐ - │ Inactivity Timeout │ │ Devices │ - │ │ │ map │ - │ uint64 (minutes) │ │ key: ipv4 (u32) │ - │ │ │ val: sizeof(struct device) │ - └───────────────────────────────┘ │ │ │ - └─────────────────┼─────────────────┘ - │ - ┌─────────────▼──────────────┐ - │ device struct │ - ┌─────────────────────────────────────┐ │ │ - │ User │◄─────────────┼─ userid char[20] │ - │ │ │ sessionExpiry uint64 │ - ├─────────────────────────────────────┤ │ lastPacketTime uint64 │ - │ AccountLocked │ │ deviceLock uint32 │ - │ uint32 │ └────────────────────────────┘ - ├─────────────────────────────────────┤ - │ Public Routes LPM │ - │ key uint32 │ ┌─────────────────────────────┐ - │ value policies[128]─────────┼───────┐ │ policy struct │ - │ │ │ │ policy_type uint16 │ - ├─────────────────────────────────────┤ ├────►│ lower_port uint16 │ - │ MFA Routes LPM │ │ │ upper_port uint16 │ - │ key uint32 │ │ │ proto uint16 │ - │ value policies[128] ────────┼───────┘ │ │ - │ │ └─────────────────────────────┘ - └─────────────────────────────────────┘ + ┌──────────────────────────────┐ + │ Devices │ + │ map │ + │ key: ipv4 (u32) │ + │ val: sizeof(struct device) │ + │ │ │ + ┌────────────────────────────────────────────┐ └──────────────┼───────────────┘ + │ Policies │ │ + │ │ ┌─────────────▼──────────────┐ + │ map │ │ device struct │ + │ │ │ │ + │ key: userid char[20] ◄────────────────────┼─────────────────┼─ userid char[20] │ + │ val: Max Polices * sizeof(struct device) │ │ sessionExpiry uint64 │ + │ │ │ │ lastPacketTime uint64 │ + └──────────────────────┼─────────────────────┘ │ associatedNode uint64 │ + │ │ │ + │ └────────────────────────────┘ + Max Policies + │ + ┌──────────────▼──────────────┐ + │ policy struct │ ┌────────────────────┐ ┌─────────────────────────┐ + │ policy_type uint16 │ │ Inactivity Timeout │ │ Associated Node │ + │ lower_port uint16 │ │ │ │ │ + │ upper_port uint16 │ │ Array │ │ Array │ + │ proto uint16 │ │ │ │ │ + │ │ │ uint64 (minutes) │ │ uint64 (etcd node id) │ + └─────────────────────────────┘ └────────────────────┘ └─────────────────────────┘ ┌─────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ Packet Flow │ @@ -81,41 +80,60 @@ A massive oversimplifcation of what is in this file. │ │ │ │ │ │ │ ┌───────────▼─────────────┐ │ -│ │ Lookup User │ │ +│ │ │ │ │ │ │ ┌────────┐ │ -│ │ user = users( │ not_found(userid) │ │ │ -│ │ userid │ ───────────────────────────────────────────────► DROP │ │ -│ │ ) │ │ │ │ -│ │ │ └────────┘ │ -│ └───────────┬─────────────┘ │ +│ │ Check User Exists │ not_found(userid) │ │ │ +│ │ │ ───────────────────────────────────────────────► DROP │ │ +│ │ │ │ │ │ +│ └───────────┬─────────────┘ └────────┘ │ +│ │ │ +│ │ │ +│ node_id : uint64 │ │ │ │ │ +│ ▼ │ +│ ┌──────────────────────┐ │ +│ │ │ │ +│ │ Check │ device not associated with current ┌────────┐ │ +│ │ │ node │ │ │ +│ │ Node ID │ ───────────────────────────────────────────────► DROP │ │ +│ │ = │ │ │ │ +│ │ Peer Associated ID │ └────────┘ │ +│ │ │ │ +│ └──────────┬───────────┘ │ │ │ │ -│ device.LastPacketTime : u64 │ │ -│ dst_ip : u32 │ │ │ │ │ │ │ │ │ │ │ -│ ┌────────────▼────────────┐ ┌────────┐ │ -│ │ Routes │ matches neither public or mfa routes │ │ │ -│ │ Check │ ────────────────────────────────────────────────► DROP │ │ -│ └────────────┬────────────┘ │ │ │ -│ │ └────────┘ │ │ │ │ -│ policies struct policy[128] │ │ +│ ▼ │ +│ ┌──────────────────────┐ │ +│ │ │ │ +│ │ │ │ +│ │ │ │ +│ │ Check Policies │ │ +│ │ │◄─────────────────────┐ │ +│ │ │ │ │ +│ │ │ │ │ +│ └──────────┬──────┬────┘ Check all policies │ │ +│ │ │ 128 MAX │ │ +│ │ │ │ │ +│ │ └───────────────────────────┘ │ │ │ │ │ │ │ -│ ┌────────┴─────────┐ │ -│ │ │ ┌────────┐ │ -│ │ Check Policies │ no_match(polices,port,proto) │ │ │ -│ │ │ ───────────────────────────────────────────────► DROP │ │ -│ └────────┬─────────┘ │ │ │ -│ │ └────────┘ │ │ │ │ -│ ┌────▼────┐ │ -│ │ │ │ -│ │ PASS │ │ -│ │ │ │ -│ └─────────┘ │ +│ │ If user policies allow access to dst ip │ +│ │ and │ +│ │ (either user is authorized │ +│ │ or │ +│ │ the route is public/always allowed) │ +│ │ │ +│ │ │ +│ ┌────▼───┐ │ +│ │ │ │ +│ │ PASS │ │ +│ │ │ │ +│ └────────┘ │ +│ │ │ │ └─────────────────────────────────────────────────────────────────────────────────────────────────────────┘ */ @@ -273,6 +291,17 @@ struct icmphdr } un; }; +struct ip +{ + __u32 src_ip; + __u16 src_port; + + __u32 dst_ip; + __u16 dst_port; + + __u32 proto; +}; + struct device { __u64 sessionExpiry; @@ -284,18 +313,10 @@ struct device __u32 PAD; -} __attribute__((__packed__)); - -struct ip -{ - __u32 src_ip; - __u16 src_port; + __u64 associatedNode; - __u32 dst_ip; - __u16 dst_port; +} __attribute__((__packed__)); - __u32 proto; -}; struct bpf_map_def SEC("maps") devices = { .type = BPF_MAP_TYPE_HASH, @@ -352,6 +373,15 @@ struct bpf_map_def SEC("maps") inactivity_timeout_minutes = { .map_flags = 0, }; +// A single variable that contains the node ID +struct bpf_map_def SEC("maps") node_Id = { + .type = BPF_MAP_TYPE_ARRAY, + .max_entries = 1, + .key_size = sizeof(__u32), + .value_size = sizeof(__u64), + .map_flags = 0, +}; + /* Attempt to parse the IPv4 source address from the packet. Returns 0 if there is no IPv4 header field; otherwise returns non-zero. @@ -479,8 +509,23 @@ static __always_inline int conntrack(struct ip *ip_info) return 0; } - // // Our userland defined inactivity timeout + + // General index used to get things out of the map arrays __u32 index = 0; + + __u64 *current_node_id = bpf_map_lookup_elem(&node_Id, &index); + if (current_node_id == NULL) + { + return 0; + } + + // If the traffic comes from a peer that we are not associated with, i.e traffic is coming to a node who has not talked to this peer before + // kill it + if(*current_node_id != current_device->associatedNode) { + return 0; + } + + __u64 *inactivity_timeout = bpf_map_lookup_elem(&inactivity_timeout_minutes, &index); if (inactivity_timeout == NULL) { diff --git a/internal/routetypes/key.go b/internal/routetypes/key.go index 881ff1a7..8f432b12 100644 --- a/internal/routetypes/key.go +++ b/internal/routetypes/key.go @@ -10,7 +10,7 @@ import ( type Key struct { // first member must be a prefix u32 wide - // rest can are arbitrary + // rest can be arbitrary Prefixlen uint32 IP [4]byte } diff --git a/internal/webserver/authenticators/authenticators.go b/internal/webserver/authenticators/authenticators.go index 260d44d7..522a31b3 100644 --- a/internal/webserver/authenticators/authenticators.go +++ b/internal/webserver/authenticators/authenticators.go @@ -58,7 +58,7 @@ func ReinitaliseMethods(method ...types.MFA) ([]types.MFA, error) { lck.Lock() defer lck.Unlock() - out := []types.MFA{} + var out []types.MFA var errRet error for _, m := range method { @@ -95,7 +95,7 @@ func GetAllEnabledMethods() (r []Authenticator) { lck.RLock() defer lck.RUnlock() - order := []string{} + var order []string for k := range allMfa { order = append(order, string(k)) } @@ -115,7 +115,7 @@ func GetAllAvaliableMethods() (r []Authenticator) { lck.RLock() defer lck.RUnlock() - order := []string{} + var order []string for k := range allMfa { order = append(order, string(k)) } @@ -175,22 +175,22 @@ type Authenticator interface { Type() string - // Name that is displayed in the MFA selection table + //FriendlyName is the name that is displayed in the MFA selection table FriendlyName() string - // Redirection path that deauthenticates selected mfa method (mostly just "/" unless its externally connected to something) + //LogoutPath returns the redirection path that deauthenticates selected mfa method (mostly just "/" unless it's externally connected to something) LogoutPath() string - // Automatically added under /register_mfa/ + //RegistrationAPI automatically added under /register_mfa/ RegistrationAPI(w http.ResponseWriter, r *http.Request) - // Automatically added under /authorise/ + //AuthorisationAPI automatically added under /authorise/ AuthorisationAPI(w http.ResponseWriter, r *http.Request) - // Executed in /authorise/ path to display UI when user browses to that path + //MFAPromptUI is executed in /authorise/ path to display UI when user browses to that path MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip string) - // Executed in /register_mfa/ path to show the UI for registration + //RegistrationUI is executed in /register_mfa/ path to show the UI for registration RegistrationUI(w http.ResponseWriter, r *http.Request, username, ip string) } diff --git a/internal/webserver/authenticators/oidc.go b/internal/webserver/authenticators/oidc.go index b31bf2fb..f0050f86 100644 --- a/internal/webserver/authenticators/oidc.go +++ b/internal/webserver/authenticators/oidc.go @@ -178,7 +178,7 @@ func (o *Oidc) AuthorisationAPI(w http.ResponseWriter, r *http.Request) { } // Rather ugly way of converting []interface{} into []string{} - groups := []string{} + var groups []string for i := range groupsIntf { conv, ok := groupsIntf[i].(string) if !ok { @@ -245,10 +245,10 @@ func (o *Oidc) AuthorisationAPI(w http.ResponseWriter, r *http.Request) { rp.CodeExchangeHandler(rp.UserinfoCallback(marshalUserinfo), o.provider)(w, r) } -func (o *Oidc) MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (o *Oidc) MFAPromptUI(w http.ResponseWriter, r *http.Request, _, _ string) { rp.AuthURLHandler(o.state, o.provider)(w, r) } -func (o *Oidc) RegistrationUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (o *Oidc) RegistrationUI(w http.ResponseWriter, r *http.Request, _, _ string) { o.RegistrationAPI(w, r) } diff --git a/internal/webserver/authenticators/pam.go b/internal/webserver/authenticators/pam.go index f6e7cceb..03e8933d 100644 --- a/internal/webserver/authenticators/pam.go +++ b/internal/webserver/authenticators/pam.go @@ -77,7 +77,9 @@ func (t *Pam) RegistrationAPI(w http.ResponseWriter, r *http.Request) { } log.Println(user.Username, clientTunnelIp, "authorised") - user.EnforceMFA() + if err := user.EnforceMFA(); err != nil { + log.Println(user.Username, clientTunnelIp, "failed to enforce mfa: ", err) + } default: http.NotFound(w, r) @@ -182,7 +184,7 @@ func (t *Pam) AuthoriseFunc(w http.ResponseWriter, r *http.Request) types.Authen } } -func (t *Pam) MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (t *Pam) MFAPromptUI(w http.ResponseWriter, _ *http.Request, username, ip string) { if err := resources.Render("prompt_mfa_pam.html", w, &resources.Msg{ HelpMail: data.GetHelpMail(), NumMethods: NumberOfMethods(), @@ -191,7 +193,7 @@ func (t *Pam) MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip s } } -func (t *Pam) RegistrationUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (t *Pam) RegistrationUI(w http.ResponseWriter, _ *http.Request, username, ip string) { if err := resources.Render("register_mfa_pam.html", w, &resources.Msg{ HelpMail: data.GetHelpMail(), NumMethods: NumberOfMethods(), diff --git a/internal/webserver/authenticators/totp.go b/internal/webserver/authenticators/totp.go index b8365c21..79a632f1 100644 --- a/internal/webserver/authenticators/totp.go +++ b/internal/webserver/authenticators/totp.go @@ -135,7 +135,9 @@ func (t *Totp) RegistrationAPI(w http.ResponseWriter, r *http.Request) { } log.Println(user.Username, clientTunnelIp, "authorised") - user.EnforceMFA() + if err := user.EnforceMFA(); err != nil { + log.Println(user.Username, clientTunnelIp, "enforce mfa failed:", err) + } default: http.NotFound(w, r) @@ -216,7 +218,7 @@ func (t *Totp) AuthoriseFunc(w http.ResponseWriter, r *http.Request) types.Authe } } -func (t *Totp) MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (t *Totp) MFAPromptUI(w http.ResponseWriter, _ *http.Request, username, ip string) { if err := resources.Render("prompt_mfa_totp.html", w, &resources.Msg{ HelpMail: data.GetHelpMail(), @@ -226,7 +228,7 @@ func (t *Totp) MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip } } -func (t *Totp) RegistrationUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (t *Totp) RegistrationUI(w http.ResponseWriter, _ *http.Request, username, ip string) { if err := resources.Render("register_mfa_totp.html", w, &resources.Msg{ HelpMail: data.GetHelpMail(), diff --git a/internal/webserver/authenticators/webauthn.go b/internal/webserver/authenticators/webauthn.go index dd9aa80a..59baabb7 100644 --- a/internal/webserver/authenticators/webauthn.go +++ b/internal/webserver/authenticators/webauthn.go @@ -35,9 +35,9 @@ func (wa *Webauthn) Init() error { } wa.webauthnExecutor, err = webauthn.New(&webauthn.Config{ - RPDisplayName: d.DisplayName, // Display Name for your site - RPID: d.ID, // Generally the domain name for your site - RPOrigin: d.Origin, // The origin URL for WebAuthn requests + RPDisplayName: d.DisplayName, // Display Name for your site + RPID: d.ID, // Generally the domain name for your site + RPOrigins: []string{d.Origin}, // The origin URL for WebAuthn requests }) if err != nil { return err @@ -288,7 +288,7 @@ func (wa *Webauthn) AuthorisationAPI(w http.ResponseWriter, r *http.Request) { } } -func (wa *Webauthn) MFAPromptUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (wa *Webauthn) MFAPromptUI(w http.ResponseWriter, _ *http.Request, username, ip string) { if err := resources.Render("prompt_mfa_webauthn.html", w, &resources.Msg{ HelpMail: data.GetHelpMail(), @@ -298,7 +298,7 @@ func (wa *Webauthn) MFAPromptUI(w http.ResponseWriter, r *http.Request, username } } -func (wa *Webauthn) RegistrationUI(w http.ResponseWriter, r *http.Request, username, ip string) { +func (wa *Webauthn) RegistrationUI(w http.ResponseWriter, _ *http.Request, username, ip string) { if err := resources.Render("register_mfa_webauthn.html", w, &resources.Msg{ HelpMail: data.GetHelpMail(), @@ -392,7 +392,7 @@ func randomUint64() uint64 { // WebAuthnID returns the user's ID func (u WebauthnUser) WebAuthnID() []byte { buf := make([]byte, binary.MaxVarintLen64) - binary.PutUvarint(buf, uint64(u.id)) + binary.PutUvarint(buf, u.id) return buf } @@ -418,7 +418,7 @@ func (u *WebauthnUser) AddCredential(cred webauthn.Credential) { } -// WebAuthnCredentials returns credentials owned by the user +// WebAuthnCredential returns credential owned by the user func (u WebauthnUser) WebAuthnCredential(ID []byte) (out *webauthn.Credential) { return u.credentials[string(ID)] @@ -437,7 +437,7 @@ func (u WebauthnUser) WebAuthnCredentials() (out []*webauthn.Credential) { // with all the user's credentials func (u WebauthnUser) CredentialExcludeList() []protocol.CredentialDescriptor { - credentialExcludeList := []protocol.CredentialDescriptor{} + var credentialExcludeList []protocol.CredentialDescriptor for _, cred := range u.credentials { descriptor := protocol.CredentialDescriptor{ Type: protocol.PublicKeyCredentialType, diff --git a/internal/webserver/statemachine.go b/internal/webserver/statemachine.go index df88c313..9b7a3677 100644 --- a/internal/webserver/statemachine.go +++ b/internal/webserver/statemachine.go @@ -33,7 +33,7 @@ func registerListeners() error { } // OidcDetailsKey = "wag-config-authentication-oidc" -func oidcChanges(key string, current data.OIDC, previous data.OIDC, et data.EventType) error { +func oidcChanges(_ string, _ data.OIDC, _ data.OIDC, et data.EventType) error { switch et { case data.DELETED: authenticators.DisableMethods(types.Oidc) @@ -56,7 +56,7 @@ func oidcChanges(key string, current data.OIDC, previous data.OIDC, et data.Even } // DomainKey = "wag-config-authentication-domain" -func domainChanged(key string, current string, _ string, et data.EventType) error { +func domainChanged(_ string, _ string, _ string, et data.EventType) error { switch et { case data.MODIFIED: @@ -77,7 +77,7 @@ func domainChanged(key string, current string, _ string, et data.EventType) erro } // MethodsEnabledKey = "wag-config-authentication-methods" -func enabledMethodsChanged(key string, current []string, previous []string, et data.EventType) (err error) { +func enabledMethodsChanged(_ string, current, previous []string, et data.EventType) (err error) { switch et { case data.DELETED: authenticators.DisableMethods(authenticators.StringsToMFA(previous)...) @@ -100,7 +100,7 @@ func enabledMethodsChanged(key string, current []string, previous []string, et d } // IssuerKey = "wag-config-authentication-issuer" -func issuerKeyChanged(key string, current string, previous string, et data.EventType) error { +func issuerKeyChanged(_ string, _, _ string, et data.EventType) error { switch et { case data.DELETED: authenticators.DisableMethods(types.Totp, types.Webauthn) diff --git a/internal/webserver/web.go b/internal/webserver/web.go index d9c5b94c..fe8b2d37 100644 --- a/internal/webserver/web.go +++ b/internal/webserver/web.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/base64" "encoding/json" + "errors" "fmt" "html/template" "image/png" @@ -62,9 +63,6 @@ func Teardown() { func Start(errChan chan<- error) error { //https://blog.cloudflare.com/exposing-go-on-the-internet/ tlsConfig := &tls.Config{ - // Causes servers to use Go's default ciphersuite preferences, - // which are tuned to avoid attacks. Does nothing on clients. - PreferServerCipherSuites: true, // Only use curves which have assembly implementations CurvePreferences: []tls.CurveID{ tls.CurveP256, @@ -99,7 +97,7 @@ func Start(errChan chan<- error) error { Handler: setSecurityHeaders(public), } - if err := publicTLSServ.ListenAndServeTLS(config.Values.Webserver.Public.CertPath, config.Values.Webserver.Public.KeyPath); err != nil && err != http.ErrServerClosed { + if err := publicTLSServ.ListenAndServeTLS(config.Values.Webserver.Public.CertPath, config.Values.Webserver.Public.KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { errChan <- fmt.Errorf("TLS webserver public listener failed: %v", err) } }() @@ -142,7 +140,7 @@ func Start(errChan chan<- error) error { Handler: setSecurityHeaders(public), } - if err := publicHTTPServ.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := publicHTTPServ.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { errChan <- fmt.Errorf("HTTP webserver public listener failed: %v", err) } }() @@ -194,7 +192,7 @@ func Start(errChan chan<- error) error { TLSConfig: tlsConfig, Handler: setSecurityHeaders(tunnel), } - if err := tunnelTLSServ.ListenAndServeTLS(config.Values.Webserver.Tunnel.CertPath, config.Values.Webserver.Tunnel.KeyPath); err != nil && err != http.ErrServerClosed { + if err := tunnelTLSServ.ListenAndServeTLS(config.Values.Webserver.Tunnel.CertPath, config.Values.Webserver.Tunnel.KeyPath); err != nil && errors.Is(err, http.ErrServerClosed) { errChan <- fmt.Errorf("TLS webserver tunnel listener failed: %v", err) } @@ -229,14 +227,14 @@ func Start(errChan chan<- error) error { Handler: setSecurityHeaders(tunnel), } - if err := tunnelHTTPServ.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := tunnelHTTPServ.ListenAndServe(); err != nil && errors.Is(err, http.ErrServerClosed) { errChan <- fmt.Errorf("webserver tunnel listener failed: %v", err) } }() } - //Group the print statement so that multithreading wont disorder them + //Group the print statement so that multithreading won't disorder them log.Println("Started listening:\n", "\t\t\tTunnel Listener: ", tunnelListenAddress, "\n", "\t\t\tPublic Listener: ", config.Values.Webserver.Public.ListenAddress) @@ -384,10 +382,10 @@ func authorise(w http.ResponseWriter, r *http.Request) { mfaMethod.MFAPromptUI(w, r, user.Username, clientTunnelIp.String()) } -func reachability(w http.ResponseWriter, r *http.Request) { +func reachability(w http.ResponseWriter, _ *http.Request) { w.Header().Add("Content-Type", "text/plain") - isDrained, err := data.IsDrained(data.GetServerID()) + isDrained, err := data.IsDrained(data.GetServerID().String()) if err != nil { http.Error(w, "Failed to fetch state", http.StatusInternalServerError) return @@ -581,8 +579,8 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("type") == "mobile" { w.Header().Set("Content-Type", "text/html; charset=UTF-8") - var config bytes.Buffer - err = resources.RenderWithFuncs("interface.tmpl", &config, &wireguardInterface, template.FuncMap{ + var wireguardProfile bytes.Buffer + err = resources.RenderWithFuncs("interface.tmpl", &wireguardProfile, &wireguardInterface, template.FuncMap{ "StringsJoin": strings.Join, "Unescape": func(s string) template.HTML { return template.HTML(s) }, }) @@ -592,7 +590,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { return } - image, err := qr.Encode(config.String(), qr.M, qr.Auto) + image, err := qr.Encode(wireguardProfile.String(), qr.M, qr.Auto) if err != nil { log.Println(username, remoteAddr, "failed to generate qr code:", err) http.Error(w, "Server Error", http.StatusInternalServerError) @@ -614,12 +612,12 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { return } - qr := resources.QrCodeRegistrationDisplay{ + qrCodeBytes := resources.QrCodeRegistrationDisplay{ ImageData: template.URL("data:image/png;base64, " + base64.StdEncoding.EncodeToString(buff.Bytes())), Username: username, } - err = resources.Render("qrcode_registration.html", w, &qr) + err = resources.Render("qrcode_registration.html", w, &qrCodeBytes) if err != nil { log.Println(username, remoteAddr, "failed to execute template to show qr code wireguard config:", err) http.Error(w, "Server Error", http.StatusInternalServerError) @@ -701,7 +699,7 @@ func routes(w http.ResponseWriter, r *http.Request) { remoteAddress := utils.GetIPFromRequest(r) user, err := users.GetUserFromAddress(remoteAddress) if err != nil { - log.Println(user.Username, remoteAddress, "Could not find user: ", err) + log.Println("unknown", remoteAddress, "Could not find user: ", err) http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -728,7 +726,7 @@ func status(w http.ResponseWriter, r *http.Request) { remoteAddress := utils.GetIPFromRequest(r) user, err := users.GetUserFromAddress(remoteAddress) if err != nil { - log.Println(user.Username, remoteAddress, "Could not find user: ", err) + log.Println("unknown", remoteAddress, "Could not find user: ", err) http.Error(w, "Server Error", http.StatusInternalServerError) return } diff --git a/main.go b/main.go index b7ce62ec..05d16c2b 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "log" @@ -70,7 +71,7 @@ func root(args []string) error { } if err := cmd.Check(); err != nil { - if err != flag.ErrHelp { + if errors.Is(err, flag.ErrHelp) { fmt.Println("Error: ", err.Error()) cmd.PrintUsage() } diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index 169a57fb..25321cb0 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -147,7 +147,7 @@ func StartControlSocket() error { Handler: controlMux, } - srv.Serve(l) + log.Println("failed to serve control socket: ", srv.Serve(l)) }() return nil } diff --git a/pkg/control/server/users.go b/pkg/control/server/users.go index e56a6755..c35a5d8b 100644 --- a/pkg/control/server/users.go +++ b/pkg/control/server/users.go @@ -43,13 +43,13 @@ func listUsers(w http.ResponseWriter, r *http.Request) { return } - users, err := data.GetAllUsers() + currentUsers, err := data.GetAllUsers() if err != nil { http.Error(w, err.Error(), 500) return } - b, err := json.Marshal(users) + b, err := json.Marshal(currentUsers) if err != nil { http.Error(w, err.Error(), 500) return @@ -186,13 +186,13 @@ func listAdminUsers(w http.ResponseWriter, r *http.Request) { return } - users, err := data.GetAllAdminUsers() + currentAdminUsers, err := data.GetAllAdminUsers() if err != nil { http.Error(w, err.Error(), 500) return } - b, err := json.Marshal(users) + b, err := json.Marshal(currentAdminUsers) if err != nil { http.Error(w, err.Error(), 500) return @@ -241,7 +241,11 @@ func unlockAdminUser(w http.ResponseWriter, r *http.Request) { username := r.FormValue("username") - data.SetAdminUserUnlock(username) + err = data.SetAdminUserUnlock(username) + if err != nil { + http.Error(w, err.Error(), 500) + return + } log.Println(username, "admin unlocked") diff --git a/pkg/control/wagctl/client.go b/pkg/control/wagctl/client.go index fbbda37a..69662600 100644 --- a/pkg/control/wagctl/client.go +++ b/pkg/control/wagctl/client.go @@ -54,7 +54,7 @@ func (c *CtrlClient) simplepost(path string, form url.Values) error { return nil } -// List devices, if the username field is empty (""), then list all devices. Otherwise list the one device corrosponding to the set username +// ListDevice if the username field is empty (""), then list all devices. Otherwise list the one device corrosponding to the set username func (c *CtrlClient) ListDevice(username string) (d []data.Device, err error) { response, err := c.httpClient.Get("http://unix/device/list?username=" + url.QueryEscape(username)) @@ -282,12 +282,12 @@ func (c *CtrlClient) GetPolicies() (result []control.PolicyData, err error) { // Add wag rule func (c *CtrlClient) AddPolicy(policies control.PolicyData) error { - data, err := json.Marshal(policies) + policiesData, err := json.Marshal(policies) if err != nil { return err } - response, err := c.httpClient.Post("http://unix/config/policy/create", "application/json", bytes.NewBuffer(data)) + response, err := c.httpClient.Post("http://unix/config/policy/create", "application/json", bytes.NewBuffer(policiesData)) if err != nil { return err } @@ -307,12 +307,12 @@ func (c *CtrlClient) AddPolicy(policies control.PolicyData) error { // Edit wag rule func (c *CtrlClient) EditPolicies(policy control.PolicyData) error { - data, err := json.Marshal(policy) + polciesData, err := json.Marshal(policy) if err != nil { return err } - response, err := c.httpClient.Post("http://unix/config/policy/edit", "application/json", bytes.NewBuffer(data)) + response, err := c.httpClient.Post("http://unix/config/policy/edit", "application/json", bytes.NewBuffer(polciesData)) if err != nil { return err } @@ -331,12 +331,12 @@ func (c *CtrlClient) EditPolicies(policy control.PolicyData) error { func (c *CtrlClient) RemovePolicies(policyNames []string) error { - data, err := json.Marshal(policyNames) + policiesData, err := json.Marshal(policyNames) if err != nil { return err } - response, err := c.httpClient.Post("http://unix/config/policies/delete", "application/json", bytes.NewBuffer(data)) + response, err := c.httpClient.Post("http://unix/config/policies/delete", "application/json", bytes.NewBuffer(policiesData)) if err != nil { return err } @@ -372,7 +372,7 @@ func (c *CtrlClient) GetGroups() (result []control.GroupData, err error) { // Add wag group/s func (c *CtrlClient) AddGroup(group control.GroupData) error { - data, err := json.Marshal(group) + groupData, err := json.Marshal(group) if err != nil { return err } @@ -381,7 +381,7 @@ func (c *CtrlClient) AddGroup(group control.GroupData) error { return errors.New("group did not have the 'group:' prefix") } - response, err := c.httpClient.Post("http://unix/config/group/create", "application/json", bytes.NewBuffer(data)) + response, err := c.httpClient.Post("http://unix/config/group/create", "application/json", bytes.NewBuffer(groupData)) if err != nil { return err } @@ -401,7 +401,7 @@ func (c *CtrlClient) AddGroup(group control.GroupData) error { // Edit wag group members func (c *CtrlClient) EditGroup(group control.GroupData) error { - data, err := json.Marshal(group) + groupData, err := json.Marshal(group) if err != nil { return err } @@ -410,7 +410,7 @@ func (c *CtrlClient) EditGroup(group control.GroupData) error { return errors.New("group did not have the 'group:' prefix") } - response, err := c.httpClient.Post("http://unix/config/group/edit", "application/json", bytes.NewBuffer(data)) + response, err := c.httpClient.Post("http://unix/config/group/edit", "application/json", bytes.NewBuffer(groupData)) if err != nil { return err } @@ -429,12 +429,12 @@ func (c *CtrlClient) EditGroup(group control.GroupData) error { func (c *CtrlClient) RemoveGroup(groupNames []string) error { - data, err := json.Marshal(groupNames) + groupData, err := json.Marshal(groupNames) if err != nil { return err } - response, err := c.httpClient.Post("http://unix/config/group/delete", "application/json", bytes.NewBuffer(data)) + response, err := c.httpClient.Post("http://unix/config/group/delete", "application/json", bytes.NewBuffer(groupData)) if err != nil { return err } diff --git a/ui/clustering.go b/ui/clustering.go index 4cfe7314..0b0e8494 100644 --- a/ui/clustering.go +++ b/ui/clustering.go @@ -14,6 +14,7 @@ import ( type MembershipDTO struct { *membership.Member IsDrained bool + IsWitness bool Ping string Status string @@ -48,14 +49,21 @@ func clusterMembersUI(w http.ResponseWriter, r *http.Request) { }, Leader: data.GetLeader(), - CurrentNode: data.GetServerID(), + CurrentNode: data.GetServerID().String(), } members := data.GetMembers() for i := range data.GetMembers() { drained, err := data.IsDrained(members[i].ID.String()) if err != nil { - log.Println("unable to render clustering page: ", err) + log.Println("unable to get drained state: ", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + witness, err := data.IsWitness(members[i].ID.String()) + if err != nil { + log.Println("unable to witness state: ", err) w.WriteHeader(http.StatusInternalServerError) return } @@ -92,6 +100,7 @@ func clusterMembersUI(w http.ResponseWriter, r *http.Request) { d.Members = append(d.Members, MembershipDTO{ Member: members[i], IsDrained: drained, + IsWitness: witness, Status: status, Ping: ping, }) @@ -184,7 +193,7 @@ func nodeControl(w http.ResponseWriter, r *http.Request) { log.Println("attempting to remove node ", ncR.Node) - if data.GetServerID() == ncR.Node { + if data.GetServerID().String() == ncR.Node { log.Println("user tried to remove current operating node from cluster") http.Error(w, "cannot remove current node", http.StatusBadRequest) return diff --git a/ui/devices.go b/ui/devices.go index 62105134..c8c2335a 100644 --- a/ui/devices.go +++ b/ui/devices.go @@ -63,10 +63,10 @@ func devicesMgmt(w http.ResponseWriter, r *http.Request) { return } - data := []DevicesData{} + var deviceData []DevicesData for _, dev := range allDevices { - data = append(data, DevicesData{ + deviceData = append(deviceData, DevicesData{ Owner: dev.Username, Locked: dev.Attempts >= lockout, InternalIP: dev.Address, @@ -76,7 +76,7 @@ func devicesMgmt(w http.ResponseWriter, r *http.Request) { }) } - b, err := json.Marshal(data) + b, err := json.Marshal(deviceData) if err != nil { log.Println("unable to marshal devices data: ", err) diff --git a/ui/diagnostics.go b/ui/diagnostics.go index 6a4ee143..72dab093 100644 --- a/ui/diagnostics.go +++ b/ui/diagnostics.go @@ -105,7 +105,7 @@ func wgDiagnositicsData(w http.ResponseWriter, r *http.Request) { return } - data := []WgDevicesData{} + var wireguardPeers []WgDevicesData for _, peer := range peers { ip := "-" @@ -113,7 +113,7 @@ func wgDiagnositicsData(w http.ResponseWriter, r *http.Request) { ip = peer.AllowedIPs[0].String() } - data = append(data, WgDevicesData{ + wireguardPeers = append(wireguardPeers, WgDevicesData{ ReceiveBytes: peer.ReceiveBytes, TransmitBytes: peer.TransmitBytes, @@ -125,7 +125,7 @@ func wgDiagnositicsData(w http.ResponseWriter, r *http.Request) { }) } - result, err := json.Marshal(data) + result, err := json.Marshal(wireguardPeers) if err != nil { log.Println("unable to marshal peers data: ", err) http.Error(w, "Bad Request", http.StatusBadRequest) diff --git a/ui/groups.go b/ui/groups.go index ae4e97e1..1b7b8a43 100644 --- a/ui/groups.go +++ b/ui/groups.go @@ -76,6 +76,7 @@ func groups(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return case "PUT": var group control.GroupData err := json.NewDecoder(r.Body).Decode(&group) @@ -92,6 +93,7 @@ func groups(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return case "POST": var group control.GroupData err := json.NewDecoder(r.Body).Decode(&group) @@ -110,6 +112,7 @@ func groups(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return default: http.NotFound(w, r) return diff --git a/ui/notifications.go b/ui/notifications.go index b147f46f..e11eb350 100644 --- a/ui/notifications.go +++ b/ui/notifications.go @@ -152,10 +152,10 @@ func startUpdateChecker(notifications chan<- Notification) { log.Println("unable to fetch updates: ", err) return } - defer resp.Body.Close() var gr githubResponse err = json.NewDecoder(resp.Body).Decode(&gr) + resp.Body.Close() if err != nil { log.Println("unable to parse update json: ", err) return diff --git a/ui/policies.go b/ui/policies.go index b47ee008..a8cfa7ac 100644 --- a/ui/policies.go +++ b/ui/policies.go @@ -77,6 +77,7 @@ func policies(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return case "PUT": var group control.PolicyData err := json.NewDecoder(r.Body).Decode(&group) @@ -95,6 +96,7 @@ func policies(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return case "POST": var policy control.PolicyData err := json.NewDecoder(r.Body).Decode(&policy) @@ -113,6 +115,7 @@ func policies(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return default: http.NotFound(w, r) return diff --git a/ui/registration.go b/ui/registration.go index d460db54..caf2453d 100644 --- a/ui/registration.go +++ b/ui/registration.go @@ -59,10 +59,10 @@ func registrationTokens(w http.ResponseWriter, r *http.Request) { return } - data := []TokensData{} + var tokens []TokensData for _, reg := range registrations { - data = append(data, TokensData{ + tokens = append(tokens, TokensData{ Username: reg.Username, Token: reg.Token, Groups: reg.Groups, @@ -71,7 +71,7 @@ func registrationTokens(w http.ResponseWriter, r *http.Request) { }) } - b, err := json.Marshal(data) + b, err := json.Marshal(tokens) if err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return @@ -106,6 +106,7 @@ func registrationTokens(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return case "POST": @@ -152,6 +153,7 @@ func registrationTokens(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return default: http.NotFound(w, r) diff --git a/ui/security.go b/ui/security.go index 0c767181..726e2262 100644 --- a/ui/security.go +++ b/ui/security.go @@ -12,7 +12,7 @@ type security struct { func (sh *security) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("Strict-Transport-Security", "max-age=31536000") - w.Header().Set("ContentTypeNosniff", "nosniff") + w.Header().Set("X-Content-Type-Options", "nosniff") if r.Method != "GET" { u, err := url.Parse(r.Header.Get("Origin")) diff --git a/ui/statemanager.go b/ui/statemanager.go index b9e9b00f..1f4f96b8 100644 --- a/ui/statemanager.go +++ b/ui/statemanager.go @@ -9,5 +9,5 @@ var ( func watchClusterHealth(state string) { clusterState = state - serverID = data.GetServerID() + serverID = data.GetServerID().String() } diff --git a/ui/templates/cluster/members.html b/ui/templates/cluster/members.html index 7f0a189b..f08989bf 100755 --- a/ui/templates/cluster/members.html +++ b/ui/templates/cluster/members.html @@ -46,7 +46,7 @@
{{if Role:
- {{if eq .ID $.Leader}}Leader{{else if .IsLearner}}Learner{{else}}Member{{end}} + {{if eq .ID $.Leader}}Leader{{else if .IsLearner}}Learner{{else if .IsWitness}}Witness{{else}}Member{{end}}
@@ -90,11 +90,13 @@
{{if {{end}} + {{if not .IsWitness}} {{if .IsDrained}} Restore{{else}} Drain{{end}} + {{end}}
diff --git a/ui/ui_webserver.go b/ui/ui_webserver.go index fda88cdb..f86aff61 100644 --- a/ui/ui_webserver.go +++ b/ui/ui_webserver.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "crypto/tls" "encoding/hex" + "errors" "fmt" "html" "html/template" @@ -84,7 +85,7 @@ func render(w http.ResponseWriter, r *http.Request, model interface{}, content . parsed, err = template.New(name).Funcs(funcsMap).ParseFS(templatesContent, content...) } else { - realFiles := []string{} + var realFiles []string for _, c := range content { realFiles = append(realFiles, filepath.Join("ui/", c)) } @@ -226,7 +227,7 @@ func StartWebServer(errs chan<- error) error { if data.HasLeader() { clusterState = "healthy" } - serverID = data.GetServerID() + serverID = data.GetServerID().String() _, err = data.RegisterClusterHealthListener(watchClusterHealth) if err != nil { @@ -237,9 +238,6 @@ func StartWebServer(errs chan<- error) error { //https://blog.cloudflare.com/exposing-go-on-the-internet/ tlsConfig := &tls.Config{ - // Causes servers to use Go's default ciphersuite preferences, - // which are tuned to avoid attacks. Does nothing on clients. - PreferServerCipherSuites: true, // Only use curves which have assembly implementations CurvePreferences: []tls.CurveID{ tls.CurveP256, @@ -369,7 +367,7 @@ func StartWebServer(errs chan<- error) error { Handler: setSecurityHeaders(allRoutes), } - if err := HTTPSServer.ListenAndServeTLS(config.Values.ManagementUI.CertPath, config.Values.ManagementUI.KeyPath); err != nil && err != http.ErrServerClosed { + if err := HTTPSServer.ListenAndServeTLS(config.Values.ManagementUI.CertPath, config.Values.ManagementUI.KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { errs <- fmt.Errorf("TLS management listener failed: %v", err) } @@ -383,7 +381,7 @@ func StartWebServer(errs chan<- error) error { IdleTimeout: 120 * time.Second, Handler: setSecurityHeaders(allRoutes), } - if err := HTTPServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := HTTPServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { errs <- fmt.Errorf("webserver management listener failed: %v", HTTPServer.ListenAndServe()) } diff --git a/ui/users.go b/ui/users.go index b6889e8d..a2954082 100644 --- a/ui/users.go +++ b/ui/users.go @@ -61,7 +61,7 @@ func manageUsers(w http.ResponseWriter, r *http.Request) { return } - usersData := []UsersData{} + var usersData []UsersData for _, u := range users { devices, _ := ctrl.ListDevice(u.Username) @@ -90,6 +90,7 @@ func manageUsers(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write(b) + return case "PUT": var action struct { Action string `json:"action"` @@ -133,6 +134,7 @@ func manageUsers(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return case "DELETE": var usernames []string @@ -160,6 +162,7 @@ func manageUsers(w http.ResponseWriter, r *http.Request) { } w.Write([]byte("OK")) + return default: http.NotFound(w, r)