Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added TrustedProxies parameter #339

Open
wants to merge 21 commits into
base: unstable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions client/src/pages/config/users/configman.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ const ConfigManagement = () => {
GenerateMissingAuthCert: config.HTTPConfig.GenerateMissingAuthCert,
HTTPPort: config.HTTPConfig.HTTPPort,
HTTPSPort: config.HTTPConfig.HTTPSPort,
TrustedProxies: config.HTTPConfig.TrustedProxies && config.HTTPConfig.TrustedProxies.join(', '),
SSLEmail: config.HTTPConfig.SSLEmail,
UseWildcardCertificate: config.HTTPConfig.UseWildcardCertificate,
HTTPSCertificateMode: config.HTTPConfig.HTTPSCertificateMode,
Expand Down Expand Up @@ -205,6 +206,8 @@ const ConfigManagement = () => {
AllowSearchEngine: values.AllowSearchEngine,
AllowHTTPLocalIPAccess: values.AllowHTTPLocalIPAccess,
PublishMDNS: values.PublishMDNS,
TrustedProxies: (values.TrustedProxies && values.TrustedProxies != "") ?
values.TrustedProxies.split(',').map((x) => x.trim()) : [],
},
EmailConfig: {
...config.EmailConfig,
Expand Down Expand Up @@ -615,6 +618,19 @@ const ConfigManagement = () => {
)}
</Stack>
</Grid>

<Grid item xs={12}>
<Stack spacing={1}>
<Alert severity="info">
{t('mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText')}<br />
</Alert>
<CosmosInputText
label={t('mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel')}
name="TrustedProxies"
formik={formik}
/>
</Stack>
</Grid>
<Grid item xs={12}>
<Alert severity="info">
{t('mgmt.config.http.allowSearchIndexCheckbox')}<br />
Expand Down
2 changes: 2 additions & 0 deletions client/src/utils/locales/en/translation.json
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@
"mgmt.config.http.hostnameInput.HostnameLabel": "Hostname: This will be used to restrict access to your Cosmos Server (Your IP, or your domain name)",
"mgmt.config.http.hostnameInput.HostnameValidation": "Hostname is required",
"mgmt.config.http.publishMDNSCheckbox": "This allows you to publish your server on your local network using mDNS. This means all your .local domains will be available on your local network with no additional config.",
"mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel": "Trusted proxies allow X-Forwarded-For from IP/IP range.",
"mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText": "Use this setting when you have an upstream proxy server to avoid it being blocked by Shield. IPs or IP ranges separated by commas.",
"mgmt.config.email.notifyLoginCheckbox.notifyLoginLabel": "Notify Users upon Successful Login",
"mgmt.config.proxy.noRoutesConfiguredText": "No routes configured.",
"mgmt.config.proxy.originTitle": "Origin",
Expand Down
2 changes: 2 additions & 0 deletions client/src/utils/locales/fr/translation.json
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@
"mgmt.config.http.hostnameInput.HostnameLabel": "Nom d'hôte : Cela sera utilisé pour restreindre l'accès à votre serveur Cosmos (Votre IP, ou votre nom de domaine)",
"mgmt.config.http.hostnameInput.HostnameValidation": "Le nom d'hôte est obligatoire",
"mgmt.config.http.publishMDNSCheckbox": "Cela vous permet de publier votre serveur sur votre réseau local en utilisant mDNS. Cela signifie que tous vos domaines .local seront disponibles sur votre réseau local sans configuration supplémentaire.",
"mgmt.config.http.TrustedProxiesInput.TrustesProxiesLabel": "IPs/Plages IP des proxys de confiance pour l'utilisation de X-Forwarded-For.",
"mgmt.config.http.TrustedProxiesInput.TrustesProxiesHelperText": "Utilisez ce paramètre lorsque vous avez un serveur proxy en amont pour éviter le blocage de celui-ci par le Shield. IPs ou plages IP séparées par des virgules.",
"mgmt.config.email.notifyLoginCheckbox.notifyLoginLabel": "Notifier les utilisateurs en cas de connexion réussie",
"mgmt.config.proxy.noRoutesConfiguredText": "Aucune route configurée.",
"mgmt.config.proxy.originTitle": "Origine",
Expand Down
2 changes: 2 additions & 0 deletions src/httpServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ func InitServer() *mux.Router {

router := mux.NewRouter().StrictSlash(true)

router.Use(utils.ClientRealIP)

router.Use(utils.BlockBannedIPs)

router.Use(utils.Logger)
Expand Down
22 changes: 14 additions & 8 deletions src/proxy/shield.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"math"
"strconv"
"strings"

"github.com/azukaar/cosmos-server/src/utils"
"github.com/azukaar/cosmos-server/src/metrics"
Expand Down Expand Up @@ -296,14 +297,19 @@ func calculateLowestExhaustedPercentage(policy utils.SmartShieldPolicy, userCons
func GetClientID(r *http.Request, route utils.ProxyRouteConfig) string {
// when using Docker we need to get the real IP
remoteAddr, _ := utils.SplitIP(r.RemoteAddr)
UseForwardedFor := utils.GetMainConfig().HTTPConfig.UseForwardedFor
isTunneledIp := constellation.GetDeviceIp(route.TunnelVia) == remoteAddr
isConstIP := utils.IsConstellationIP(remoteAddr)
isConstTokenValid := constellation.CheckConstellationToken(r) == nil

if (UseForwardedFor && r.Header.Get("x-forwarded-for") != "") ||
(isTunneledIp && isConstIP && isConstTokenValid) {
ip, _ := utils.SplitIP(r.Header.Get("x-forwarded-for"))
useForwardedForHeader := false
if r.Header.Get("x-forwarded-for") != "" {
useForwardedForHeader = utils.IsTrustedProxy(remoteAddr)
if !useForwardedForHeader {
isTunneledIp := constellation.GetDeviceIp(route.TunnelVia) == remoteAddr
isConstIP := utils.IsConstellationIP(remoteAddr)
isConstTokenValid := constellation.CheckConstellationToken(r) == nil
useForwardedForHeader = isTunneledIp && isConstIP && isConstTokenValid
}
}

if useForwardedForHeader {
ip, _ := utils.SplitIP(strings.TrimSpace(strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0]))
utils.Debug("SmartShield: Getting forwarded client ID " + ip)
return ip
} else {
Expand Down
83 changes: 51 additions & 32 deletions src/utils/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,34 +48,49 @@ func getIPAbuseCounter(ip string) int64 {
return atomic.LoadInt64(&counter.val)
}

func ClientRealIP(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientID := GetClientIP(r)
if(clientID == ""){
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}

ctx := context.WithValue(r.Context(), "ClientID", clientID)
r = r.WithContext(ctx)

next.ServeHTTP(w, r)
})
}

func BlockBannedIPs(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
if hj, ok := w.(http.Hijacker); ok {
conn, _, err := hj.Hijack()
if err == nil {
conn.Close()
}
}
return
ip, ok := r.Context().Value("ClientID").(string)
if !ok {
if hj, ok := w.(http.Hijacker); ok {
conn, _, err := hj.Hijack()
if err == nil {
conn.Close()
}
}
return
}

nbAbuse := getIPAbuseCounter(ip)
nbAbuse := getIPAbuseCounter(ip)

if nbAbuse > 275 {
Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.")
}
Warn("IP " + ip + " has " + fmt.Sprintf("%d", nbAbuse) + " abuse(s) and will soon be banned.")
}

if nbAbuse > 300 {
if hj, ok := w.(http.Hijacker); ok {
conn, _, err := hj.Hijack()
if err == nil {
conn.Close()
}
}
return
if hj, ok := w.(http.Hijacker); ok {
conn, _, err := hj.Hijack()
if err == nil {
conn.Close()
}
}
return
}

next.ServeHTTP(w, r)
})
Expand Down Expand Up @@ -204,8 +219,8 @@ func GetIPLocation(ip string) (string, error) {
func BlockByCountryMiddleware(blockedCountries []string, CountryBlacklistIsWhitelist bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip, ok := r.Context().Value("ClientID").(string)
if !ok {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
Expand Down Expand Up @@ -287,7 +302,7 @@ func BlockPostWithoutReferer(next http.Handler) http.Handler {
Error("Blocked POST request without Referer header", nil)
http.Error(w, "Bad Request: Invalid request.", http.StatusBadRequest)

ip, _, _ := net.SplitHostPort(r.RemoteAddr)
ip, _ := r.Context().Value("ClientID").(string)
if ip != "" {
TriggerEvent(
"cosmos.proxy.shield.referer",
Expand Down Expand Up @@ -346,7 +361,7 @@ func EnsureHostname(next http.Handler) http.Handler {
w.WriteHeader(http.StatusBadRequest)
http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest)

ip, _, _ := net.SplitHostPort(r.RemoteAddr)
ip, _ := r.Context().Value("ClientID").(string)
if ip != "" {
TriggerEvent(
"cosmos.proxy.shield.hostname",
Expand Down Expand Up @@ -412,7 +427,7 @@ func EnsureHostnameCosmosAPI(next http.Handler) http.Handler {
w.WriteHeader(http.StatusBadRequest)
http.Error(w, "Bad Request: Invalid hostname. Use your domain instead of your IP to access your server. Check logs if more details are needed.", http.StatusBadRequest)

ip, _, _ := net.SplitHostPort(r.RemoteAddr)
ip, _ := r.Context().Value("ClientID").(string)
if ip != "" {
TriggerEvent(
"cosmos.proxy.shield.hostname",
Expand Down Expand Up @@ -466,33 +481,37 @@ func IsValidHostname(hostname string) bool {
}

func IPInRange(ipStr, cidrStr string) (bool, error) {
_, cidrNet, err := net.ParseCIDR(cidrStr)
if err != nil {
return false, fmt.Errorf("parse CIDR range: %w", err)
}

ip := net.ParseIP(ipStr)
if ip == nil {
return false, fmt.Errorf("parse IP: invalid IP address")
}

_, cidrNet, err := net.ParseCIDR(cidrStr)
if err != nil {
if ipStr == cidrStr {
return true, nil
}
return false, fmt.Errorf("parse CIDR range: %w", err)
}

return cidrNet.Contains(ip), nil
}

func Restrictions(RestrictToConstellation bool, WhitelistInboundIPs []string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr)
ip, ok := r.Context().Value("ClientID").(string)
if (err != nil) || !ok {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}

isUsingWhiteList := len(WhitelistInboundIPs) > 0

isInWhitelist := false
isInConstellation := strings.HasPrefix(ip, "192.168.201.") || strings.HasPrefix(ip, "192.168.202.")
isInConstellation := strings.HasPrefix(remoteAddr, "192.168.201.") || strings.HasPrefix(remoteAddr, "192.168.202.")

for _, ipRange := range WhitelistInboundIPs {
Debug("Checking if " + ip + " is in " + ipRange)
Expand Down
1 change: 1 addition & 0 deletions src/utils/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ type HTTPConfig struct {
UseForwardedFor bool
AllowSearchEngine bool
PublishMDNS bool
TrustedProxies []string
}

const (
Expand Down
21 changes: 16 additions & 5 deletions src/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,11 +779,13 @@ func DownloadFile(url string) (string, error) {
}

func GetClientIP(req *http.Request) string {
/*ip := req.Header.Get("X-Forwarded-For")
if ip == "" {
ip = req.RemoteAddr
}*/
return req.RemoteAddr
// when using Docker we need to get the real IP
remoteAddr, _ := SplitIP(req.RemoteAddr)

if req.Header.Get("x-forwarded-for") != "" && IsTrustedProxy(remoteAddr) {
remoteAddr, _ = SplitIP(strings.TrimSpace(strings.Split(req.Header.Get("X-Forwarded-For"), ",")[0]))
}
return remoteAddr
}

func IsDomain(domain string) bool {
Expand Down Expand Up @@ -899,6 +901,15 @@ func IsConstellationIP(ip string) bool {
return false
}

func IsTrustedProxy(ip string) bool {
InterN0te marked this conversation as resolved.
Show resolved Hide resolved
for _, trustedProxy := range GetMainConfig().HTTPConfig.TrustedProxies {
if isInRange, _ := IPInRange(ip, trustedProxy); isInRange {
return true
}
}
return false
}

func SplitIP(ipPort string) (string, string) {
host, port, err := osnet.SplitHostPort(ipPort)
if err != nil {
Expand Down