Skip to content

Commit

Permalink
Improve Firewall
Browse files Browse the repository at this point in the history
  • Loading branch information
parttimenerd committed Aug 28, 2024
1 parent 2af668f commit 687b3a3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 23 deletions.
61 changes: 41 additions & 20 deletions bpf-samples/src/main/java/me/bechberger/ebpf/samples/Firewall.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import me.bechberger.ebpf.bpf.map.BPFRingBuffer;
import me.bechberger.ebpf.type.Enum;
import me.bechberger.ebpf.type.Ptr;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -26,17 +25,18 @@ public abstract class Firewall extends BPFProgram implements XDPHook, BasePacket
private static final Logger logger = LoggerFactory.getLogger(Firewall.class);

@Type
record IPAndPort(int ip, int port) {
record IPAndPort(int ip, int sourcePort, int destPort) {
}

@Type
record LogEntry(IPAndPort connection, long timeInMs) {
}
@Type
record LogEntry(IPAndPort connection, long timeInMs) {
}

@Type
record FirewallRule(@Unsigned int ip,
int ignoreLowBytes,
int port) {
int sourcePort,
int destPort) {
}

@Type
Expand Down Expand Up @@ -66,17 +66,29 @@ static int zeroLowBytes(int ip, int ignoreLowBytes) {
@AlwaysInline
FirewallAction computeSpecificAction(Ptr<IPAndPort> info, int ignoreLowBytes) {
int ip = info.val().ip;
var sourcePort = info.val().sourcePort;
var destPort = info.val().destPort;
// first null the bytes that should be ignored
int matchingAddressBytes = zeroLowBytes(ip, ignoreLowBytes);
if (matchingAddressBytes == 100) { // don't ask
bpf_trace_printk("Checking rule for %d:%d\n", matchingAddressBytes, info.val().port);
if (matchingAddressBytes == 0) { // don't ask
bpf_trace_printk("Checking rule for %d:%d\n", matchingAddressBytes, sourcePort);
}
var rule = new FirewallRule(matchingAddressBytes, ignoreLowBytes, info.val().port);
var rule = new FirewallRule(matchingAddressBytes, ignoreLowBytes, sourcePort, destPort);
var action = firewallRules.bpf_get(rule);
if (action != null) {
return action.val();
}
rule = new FirewallRule(matchingAddressBytes, ignoreLowBytes, -1);
rule = new FirewallRule(matchingAddressBytes, ignoreLowBytes, sourcePort, -1);
action = firewallRules.bpf_get(rule);
if (action != null) {
return action.val();
}
rule = new FirewallRule(matchingAddressBytes, ignoreLowBytes, -1, destPort);
action = firewallRules.bpf_get(rule);
if (action != null) {
return action.val();
}
rule = new FirewallRule(matchingAddressBytes, ignoreLowBytes, -1, -1);
action = firewallRules.bpf_get(rule);
if (action != null) {
return action.val();
Expand All @@ -101,21 +113,21 @@ FirewallAction computeAction(Ptr<IPAndPort> info) {
@AlwaysInline
FirewallAction getAction(Ptr<PacketInfo> packetInfo) {
IPAndPort ipAndPort = new IPAndPort(
packetInfo.val().source.ipv4(), packetInfo.val().sourcePort);
packetInfo.val().source.ipv4(), packetInfo.val().sourcePort, packetInfo.val().destinationPort);
Ptr<FirewallAction> action = resolvedRules.bpf_get(ipAndPort);
if (action != null) {
return action.val();
}
var newAction = computeAction(Ptr.of(ipAndPort));
bpf_trace_printk("Unresolved action for %d:%d %d\n", ipAndPort.ip(), ipAndPort.port(), newAction.value());
bpf_trace_printk("Unresolved action for %d:%d %d\n", ipAndPort.ip(), ipAndPort.sourcePort, newAction.value());
resolvedRules.put(ipAndPort, newAction);
return newAction;
}

@BPFFunction
@AlwaysInline
void countConnection(PacketInfo info) {
IPAndPort ipAndPort = new IPAndPort(info.source.ipv4(), info.sourcePort);
IPAndPort ipAndPort = new IPAndPort(info.source.ipv4(), info.sourcePort, info.destinationPort);
Ptr<Long> count = connectionCount.bpf_get(ipAndPort);
if (count == null) {
long one = 1;
Expand All @@ -132,7 +144,7 @@ void recordBlockedConnection(PacketInfo info) {
return;
}
ptr.set(
new LogEntry(new IPAndPort(info.source.ipv4(), info.sourcePort),
new LogEntry(new IPAndPort(info.source.ipv4(), info.sourcePort, info.destinationPort),
bpf_ktime_get_ns() / 1000000));
blockedConnections.submit(ptr);
}
Expand Down Expand Up @@ -168,17 +180,25 @@ static FirewallRuleAndAction parseRule(String rule) {
FirewallRule firewallRule;
var rulePart = rule.split(" ")[0];
if (rule.contains("/")) {
if (!rulePart.matches(".*/(0|8|16|32):.*(:.*)?")) {
throw new IllegalArgumentException("Invalid rule: " + rule + ", should match .*/(0|8|16|32):.*(:.*)?");
}
String[] parts = rulePart.split(":");
String[] ipParts = parts[0].split("/");
int ip = NetworkUtil.ipAddressToInt(ipParts[0]);
int ignoreLowBytes = 32 - Integer.parseInt(ipParts[1]) / 8;
int port = parsePort(parts[1]);
firewallRule = new FirewallRule(zeroLowBytes(ip, ignoreLowBytes), ignoreLowBytes, port);
int sourcePort = parsePort(parts[1]);
int targetPort = parts.length == 3 ? parsePort(parts[2]) : -1;
firewallRule = new FirewallRule(zeroLowBytes(ip, ignoreLowBytes), ignoreLowBytes, sourcePort, targetPort);
} else {
if (!rulePart.matches(".+:.*(:.*)?")) {
throw new IllegalArgumentException("Invalid rule: " + rule + ", should match .+:.*(:.*)?");
}
String[] parts = rulePart.split(":");
int ip = NetworkUtil.getFirstIPAddress(parts[0]);
int port = parsePort(parts[1]);
firewallRule = new FirewallRule(ip, 0, port);
int sourcePort = parsePort(parts[1]);
int targetPort = parts.length == 3 ? parsePort(parts[2]) : -1;
firewallRule = new FirewallRule(ip, 0, sourcePort, targetPort);
}
var actionPart = rule.split(" ")[1];
FirewallAction action = switch (actionPart) {
Expand All @@ -198,8 +218,9 @@ public static void main(String[] args) throws InterruptedException {
}
program.xdpAttach();
program.blockedConnections.setCallback((info) -> {
logger.info("Blocked packet from {} port {}",
NetworkUtil.intToIpAddress(info.connection.ip).getHostAddress(), info.connection.port);
logger.info("Blocked packet from {} port {} to port {}",
NetworkUtil.intToIpAddress(info.connection.ip).getHostAddress(),
info.connection.sourcePort, info.connection.destPort);
});
while (true) {
program.consumeAndThrow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,14 @@ public String index() {
<h3>Send Custom JSON to /rawDrop</h3>
<p>Like <code>{"ip": 0, "ignoreLowBytes": 4, "port": 443}</code></p>
<div class="input-group">
<input type="text" id="jsonInput" value='{"ip": 0, "ignoreLowBytes": 4, "port": 443}'>
<input type="text" id="jsonInput" value='{"ip": 0, "ignoreLowBytes": 4, "sourcePort": 443, "destPort": -1}'>
<button onclick="sendJson()">Send JSON</button>
</div>
</div>
<div>
<h3>Add a Rule to /add</h3>
<p>Like <code>google.com:HTTP drop</code></p>
<p>Like <code>google.com:HTTPS drop</code></p>
<div class="input-group">
<input type="text" id="ruleInput">
<button onclick="addRule()">Add Rule</button>
Expand Down
3 changes: 2 additions & 1 deletion bpf/src/main/java/me/bechberger/ebpf/bpf/XDPHook.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package me.bechberger.ebpf.bpf;

import me.bechberger.ebpf.annotations.Unsigned;
import me.bechberger.ebpf.annotations.bpf.MethodIsBPFRelatedFunction;
import me.bechberger.ebpf.annotations.bpf.BPFFunction;
import me.bechberger.ebpf.annotations.bpf.BuiltinBPFFunction;
Expand All @@ -25,7 +26,7 @@ public interface XDPHook {
short ETH_P_8021Q = (short)0x8100;
short ETH_P_8021AD = (short)0x88A8;
short ETH_P_IP = (short)0x0800;
short ETH_P_IPV6 = (short)0x86DD;
int ETH_P_IPV6 = 0x86DD;
short ETH_P_ARP = (short)0x0806;

/**
Expand Down

0 comments on commit 687b3a3

Please sign in to comment.