diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index d56694de..15688c4f 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -30,7 +30,7 @@ type Manager struct { lock sync.RWMutex log logging.LeveledLogger - allocations map[string]*Allocation + allocations map[FiveTupleFingerprint]*Allocation reservations []*reservation allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) @@ -51,7 +51,7 @@ func NewManager(config ManagerConfig) (*Manager, error) { return &Manager{ log: config.LeveledLogger, - allocations: make(map[string]*Allocation, 64), + allocations: make(map[FiveTupleFingerprint]*Allocation, 64), allocatePacketConn: config.AllocatePacketConn, allocateConn: config.AllocateConn, permissionHandler: config.PermissionHandler, diff --git a/internal/allocation/five_tuple.go b/internal/allocation/five_tuple.go index 99264623..010021b3 100644 --- a/internal/allocation/five_tuple.go +++ b/internal/allocation/five_tuple.go @@ -4,7 +4,6 @@ package allocation import ( - "fmt" "net" ) @@ -33,7 +32,31 @@ func (f *FiveTuple) Equal(b *FiveTuple) bool { return f.Fingerprint() == b.Fingerprint() } +// FiveTupleFingerprint is a comparable representation of a FiveTuple +type FiveTupleFingerprint struct { + srcIP, dstIP [16]byte + srcPort, dstPort uint16 + protocol Protocol +} + // Fingerprint is the identity of a FiveTuple -func (f *FiveTuple) Fingerprint() string { - return fmt.Sprintf("%d_%s_%s", f.Protocol, f.SrcAddr.String(), f.DstAddr.String()) +func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) { + switch f.Protocol { + case UDP: + src := f.SrcAddr.(*net.UDPAddr) + copy(fp.srcIP[:], src.IP) + fp.srcPort = uint16(src.Port) + dst := f.SrcAddr.(*net.UDPAddr) + copy(fp.dstIP[:], dst.IP) + fp.dstPort = uint16(dst.Port) + case TCP: + src := f.SrcAddr.(*net.TCPAddr) + copy(fp.srcIP[:], src.IP) + fp.srcPort = uint16(src.Port) + dst := f.SrcAddr.(*net.TCPAddr) + copy(fp.dstIP[:], dst.IP) + fp.dstPort = uint16(dst.Port) + } + fp.protocol = f.Protocol + return }