From 6eefa1c986ef0f7c777155b97784aac354228e05 Mon Sep 17 00:00:00 2001 From: Gabor Retvari Date: Thu, 5 Dec 2024 10:59:56 +0100 Subject: [PATCH] Add server-side lifecycle event handler API --- internal/allocation/allocation.go | 95 +++- internal/allocation/allocation_manager.go | 35 +- .../allocation/allocation_manager_test.go | 20 +- internal/allocation/allocation_test.go | 22 +- internal/allocation/channel_bind_test.go | 2 +- internal/allocation/event_handler.go | 39 ++ internal/allocation/five_tuple.go | 15 +- internal/client/binding_test.go | 2 +- internal/client/periodic_timer_test.go | 2 +- internal/proto/chandata.go | 2 +- internal/server/nonce.go | 4 +- internal/server/server.go | 3 +- internal/server/turn.go | 11 +- internal/server/turn_test.go | 2 +- internal/server/util.go | 32 ++ lt_cred.go | 2 +- relay_address_generator_range.go | 2 +- server.go | 7 + server_config.go | 80 +++ server_test.go | 458 +++++++++++++++++- 20 files changed, 788 insertions(+), 47 deletions(-) create mode 100644 internal/allocation/event_handler.go diff --git a/internal/allocation/allocation.go b/internal/allocation/allocation.go index 5b5ff369..9d0286dc 100644 --- a/internal/allocation/allocation.go +++ b/internal/allocation/allocation.go @@ -35,6 +35,8 @@ type Allocation struct { channelBindings []*ChannelBind lifetimeTimer *time.Timer closed chan interface{} + username, realm string + callback EventHandler log logging.LeveledLogger // Some clients (Firefox or others using resiprocate's nICE lib) may retry allocation @@ -45,12 +47,13 @@ type Allocation struct { } // NewAllocation creates a new instance of NewAllocation. -func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, log logging.LeveledLogger) *Allocation { +func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, callback EventHandler, log logging.LeveledLogger) *Allocation { return &Allocation{ TurnSocket: turnSocket, fiveTuple: fiveTuple, permissions: make(map[string]*Permission, 64), closed: make(chan interface{}), + callback: callback, log: log, } } @@ -81,6 +84,21 @@ func (a *Allocation) AddPermission(p *Permission) { a.permissions[fingerprint] = p a.permissionsLock.Unlock() + if a.callback != nil { + if u, ok := p.Addr.(*net.UDPAddr); ok { + a.log.Trace("Calling OnPermissionCreated event handler") + a.callback(EventHandlerArgs{ + Type: OnPermissionCreated, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + PeerIP: u.IP, + }) + } + } + p.start(permissionTimeout) } @@ -89,6 +107,32 @@ func (a *Allocation) RemovePermission(addr net.Addr) { a.permissionsLock.Lock() defer a.permissionsLock.Unlock() delete(a.permissions, ipnet.FingerprintAddr(addr)) + + if a.callback != nil { + if u, ok := addr.(*net.UDPAddr); ok { + a.log.Trace("Calling OnPermissionDeleted event handler") + a.callback(EventHandlerArgs{ + Type: OnPermissionDeleted, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + PeerIP: u.IP, + }) + } + } +} + +// ListPermissions returns the permissions associated with an allocation. +func (a *Allocation) ListPermissions() []*Permission { + ps := []*Permission{} + a.permissionsLock.RLock() + defer a.permissionsLock.RUnlock() + for _, p := range a.permissions { + ps = append(ps, p) + } + return ps } // AddChannelBind adds a new ChannelBind to the allocation, it also updates the @@ -113,6 +157,21 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro // Channel binds also refresh permissions. a.AddPermission(NewPermission(c.Peer, a.log)) + + if a.callback != nil { + a.log.Trace("Calling OnChannelCreated event handler") + a.callback(EventHandlerArgs{ + Type: OnChannelCreated, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + RelayAddr: a.RelayAddr, + PeerAddr: c.Peer, + ChannelNumber: uint16(c.Number), + }) + } } else { channelByNumber.refresh(lifetime) @@ -130,6 +189,21 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool { for i := len(a.channelBindings) - 1; i >= 0; i-- { if a.channelBindings[i].Number == number { + if a.callback != nil { + a.log.Trace("Calling OnChannelDeleted event handler") + a.callback(EventHandlerArgs{ + Type: OnChannelDeleted, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + RelayAddr: a.RelayAddr, + PeerAddr: a.channelBindings[i].Peer, + ChannelNumber: uint16(a.channelBindings[i].Number), + }) + } + a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...) return true } @@ -162,6 +236,15 @@ func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind { return nil } +// ListChannelBindings returns the channel bindings associated with an allocation. +func (a *Allocation) ListChannelBindings() []*ChannelBind { + cs := []*ChannelBind{} + a.channelBindingsLock.RLock() + defer a.channelBindingsLock.RUnlock() + cs = append(cs, a.channelBindings...) + return cs +} + // Refresh updates the allocations lifetime func (a *Allocation) Refresh(lifetime time.Duration) { if !a.lifetimeTimer.Reset(lifetime) { @@ -196,17 +279,15 @@ func (a *Allocation) Close() error { a.lifetimeTimer.Stop() - a.permissionsLock.RLock() - for _, p := range a.permissions { + for _, p := range a.ListPermissions() { + a.RemovePermission(p.Addr) p.lifetimeTimer.Stop() } - a.permissionsLock.RUnlock() - a.channelBindingsLock.RLock() - for _, c := range a.channelBindings { + for _, c := range a.ListChannelBindings() { + a.RemoveChannelBind(c.Number) c.lifetimeTimer.Stop() } - a.channelBindingsLock.RUnlock() return a.RelaySocket.Close() } diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index 2b765921..6567e8a1 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -18,6 +18,7 @@ type ManagerConfig struct { AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool + EventHandler EventHandler } type reservation struct { @@ -36,6 +37,7 @@ type Manager struct { allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool + EventHandler EventHandler } // NewManager creates a new instance of Manager. @@ -55,6 +57,7 @@ func NewManager(config ManagerConfig) (*Manager, error) { allocatePacketConn: config.AllocatePacketConn, allocateConn: config.AllocateConn, permissionHandler: config.PermissionHandler, + EventHandler: config.EventHandler, }, nil } @@ -86,7 +89,7 @@ func (m *Manager) Close() error { } // CreateAllocation creates a new allocation and starts relaying -func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) { +func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration, username, realm string) (*Allocation, error) { switch { case fiveTuple == nil: return nil, errNilFiveTuple @@ -103,7 +106,9 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo if a := m.GetAllocation(fiveTuple); a != nil { return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple) } - a := NewAllocation(turnSocket, fiveTuple, m.log) + a := NewAllocation(turnSocket, fiveTuple, m.EventHandler, m.log) + a.username = username + a.realm = realm conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort) if err != nil { @@ -123,6 +128,20 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo m.allocations[fiveTuple.Fingerprint()] = a m.lock.Unlock() + if m.EventHandler != nil { + m.log.Trace("Calling OnAllocationCreated event handler") + m.EventHandler(EventHandlerArgs{ + Type: OnAllocationCreated, + SrcAddr: fiveTuple.SrcAddr, + DstAddr: fiveTuple.DstAddr, + Protocol: UDP, + Username: username, + Realm: realm, + RelayAddr: relayAddr, + RequestedPort: requestedPort, + }) + } + go a.packetHandler(m) return a, nil } @@ -143,6 +162,18 @@ func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) { if err := allocation.Close(); err != nil { m.log.Errorf("Failed to close allocation: %v", err) } + + if m.EventHandler != nil { + m.log.Trace("Calling OnAllocationDeleted event handler") + m.EventHandler(EventHandlerArgs{ + Type: OnAllocationDeleted, + SrcAddr: fiveTuple.SrcAddr, + DstAddr: fiveTuple.DstAddr, + Protocol: UDP, + Username: allocation.username, + Realm: allocation.realm, + }) + } } // CreateReservation stores the reservation for the token+port diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 014d85d3..350ec795 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -52,13 +52,13 @@ func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { m, err := newTestManager() assert.NoError(t, err) - if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil { t.Errorf("Illegally created allocation with nil FiveTuple") } - if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil { t.Errorf("Illegally created allocation with nil turnSocket") } - if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0, "", ""); a != nil || err == nil { t.Errorf("Illegally created allocation with 0 lifetime") } } @@ -69,7 +69,7 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -84,11 +84,11 @@ func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.Pack assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil { t.Errorf("Was able to create allocation with same FiveTuple twice") } } @@ -98,7 +98,7 @@ func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -123,7 +123,7 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { for index := range allocations { fiveTuple := randomFiveTuple() - a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime) + a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime, "", "") if err != nil { t.Errorf("Failed to create allocation with %v", fiveTuple) } @@ -147,9 +147,9 @@ func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) { allocations := make([]*Allocation, 2) - a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second) + a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second, "", "") allocations[0] = a1 - a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute) + a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute, "", "") allocations[1] = a2 // Make a1 timeout diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index 49269d68..cb2d4d14 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -46,7 +46,7 @@ func TestAllocation(t *testing.T) { } func subTestGetPermission(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -88,7 +88,7 @@ func subTestGetPermission(t *testing.T) { } func subTestAddPermission(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -107,7 +107,7 @@ func subTestAddPermission(t *testing.T) { } func subTestRemovePermission(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -130,7 +130,7 @@ func subTestRemovePermission(t *testing.T) { } func subTestAddChannelBind(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -154,7 +154,7 @@ func subTestAddChannelBind(t *testing.T) { } func subTestGetChannelByNumber(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -173,7 +173,7 @@ func subTestGetChannelByNumber(t *testing.T) { } func subTestGetChannelByAddr(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -193,7 +193,7 @@ func subTestGetChannelByAddr(t *testing.T) { } func subTestRemoveChannelBind(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -214,7 +214,7 @@ func subTestRemoveChannelBind(t *testing.T) { } func subTestAllocationRefresh(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) var wg sync.WaitGroup wg.Add(1) @@ -236,7 +236,7 @@ func subTestAllocationClose(t *testing.T) { panic(err) } - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) a.RelaySocket = l // Add mock lifetimeTimer a.lifetimeTimer = time.AfterFunc(proto.DefaultLifetime, func() {}) @@ -292,7 +292,7 @@ func subTestPacketHandler(t *testing.T) { a, err := m.CreateAllocation(&FiveTuple{ SrcAddr: clientListener.LocalAddr(), DstAddr: turnSocket.LocalAddr(), - }, turnSocket, 0, proto.DefaultLifetime) + }, turnSocket, 0, proto.DefaultLifetime, "", "") assert.Nil(t, err, "should succeed") @@ -357,7 +357,7 @@ func subTestPacketHandler(t *testing.T) { } func subTestResponseCache(t *testing.T) { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) transactionID := [stun.TransactionIDSize]byte{1, 2, 3} responseAttrs := []stun.Setter{ &proto.Lifetime{ diff --git a/internal/allocation/channel_bind_test.go b/internal/allocation/channel_bind_test.go index 30e3034a..4d72e456 100644 --- a/internal/allocation/channel_bind_test.go +++ b/internal/allocation/channel_bind_test.go @@ -42,7 +42,7 @@ func TestChannelBindReset(t *testing.T) { } func newChannelBind(lifetime time.Duration) *ChannelBind { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, _ := net.ResolveUDPAddr("udp", "0.0.0.0:0") c := &ChannelBind{ diff --git a/internal/allocation/event_handler.go b/internal/allocation/event_handler.go new file mode 100644 index 00000000..b5b718c2 --- /dev/null +++ b/internal/allocation/event_handler.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "net" +) + +// EventHandlerType is a type for signaling low-level event callbacks to the server. +type EventHandlerType int + +// Event handler types. +const ( + UnknownEvent EventHandlerType = iota + OnAuth + OnAllocationCreated + OnAllocationDeleted + OnAllocationError + OnPermissionCreated + OnPermissionDeleted + OnChannelCreated + OnChannelDeleted +) + +// EventHandlerArgs is a set of arguments passed from the low-level event callbacks to the server. +type EventHandlerArgs struct { + Type EventHandlerType + SrcAddr, DstAddr, RelayAddr, PeerAddr net.Addr + Protocol Protocol + Username, Realm, Method, Message string + Verdict bool + RequestedPort int + PeerIP net.IP + ChannelNumber uint16 +} + +// EventHandler is a callback used by the server to surface allocation lifecycle events. +type EventHandler func(EventHandlerArgs) diff --git a/internal/allocation/five_tuple.go b/internal/allocation/five_tuple.go index 6d812caf..577424bc 100644 --- a/internal/allocation/five_tuple.go +++ b/internal/allocation/five_tuple.go @@ -16,6 +16,17 @@ const ( TCP ) +func (p Protocol) String() string { + switch p { + case UDP: + return "UDP" + case TCP: + return "TCP" + default: + return "" + } +} + // FiveTuple is the combination (client IP address and port, server IP // address and port, and transport protocol (currently one of UDP, // TCP, or TLS)) used to communicate between the client and the @@ -54,9 +65,9 @@ func (f *FiveTuple) Fingerprint() (fp FiveTupleFingerprint) { func netAddrIPAndPort(addr net.Addr) (net.IP, uint16) { switch a := addr.(type) { case *net.UDPAddr: - return a.IP.To16(), uint16(a.Port) + return a.IP.To16(), uint16(a.Port) //nolint:gosec case *net.TCPAddr: - return a.IP.To16(), uint16(a.Port) + return a.IP.To16(), uint16(a.Port) //nolint:gosec default: return nil, 0 } diff --git a/internal/client/binding_test.go b/internal/client/binding_test.go index 5fd982c1..fcd6c22c 100644 --- a/internal/client/binding_test.go +++ b/internal/client/binding_test.go @@ -54,7 +54,7 @@ func TestBindingManager(t *testing.T) { if i%2 == 0 { assert.True(t, m.deleteByAddr(addr), "should return true") } else { - assert.True(t, m.deleteByNumber(minChannelNumber+uint16(i)), "should return true") + assert.True(t, m.deleteByNumber(minChannelNumber+uint16(i)), "should return true") //nolint:gosec } } diff --git a/internal/client/periodic_timer_test.go b/internal/client/periodic_timer_test.go index d77a16b3..03fa71fe 100644 --- a/internal/client/periodic_timer_test.go +++ b/internal/client/periodic_timer_test.go @@ -34,7 +34,7 @@ func TestPeriodicTimer(t *testing.T) { time.Sleep(120 * time.Millisecond) rt.Stop() assert.False(t, rt.IsRunning(), "should not be running") - assert.Equal(t, 4, int(atomic.LoadUint64(&nCbs)), "should be called 4 times (actual: %d)", atomic.LoadUint64(&nCbs)) + assert.Equal(t, 4, int(atomic.LoadUint64(&nCbs)), "should be called 4 times (actual: %d)", atomic.LoadUint64(&nCbs)) //nolint:gosec }) t.Run("stop inside handler", func(t *testing.T) { diff --git a/internal/proto/chandata.go b/internal/proto/chandata.go index df937119..e4eee622 100644 --- a/internal/proto/chandata.go +++ b/internal/proto/chandata.go @@ -90,7 +90,7 @@ func (c *ChannelData) WriteHeader() { _ = c.Raw[:channelDataHeaderSize] binary.BigEndian.PutUint16(c.Raw[:channelDataNumberSize], uint16(c.Number)) binary.BigEndian.PutUint16(c.Raw[channelDataNumberSize:channelDataHeaderSize], - uint16(len(c.Data)), + uint16(len(c.Data)), //nolint:gosec ) } diff --git a/internal/server/nonce.go b/internal/server/nonce.go index b3f3131e..d6ba3207 100644 --- a/internal/server/nonce.go +++ b/internal/server/nonce.go @@ -37,7 +37,7 @@ type NonceHash struct { // Generate a nonce func (n *NonceHash) Generate() (string, error) { nonce := make([]byte, 8, nonceLength) - binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli())) + binary.BigEndian.PutUint64(nonce, uint64(time.Now().UnixMilli())) //nolint:gosec hash := hmac.New(sha256.New, n.key) if _, err := hash.Write(nonce[:8]); err != nil { @@ -55,7 +55,7 @@ func (n *NonceHash) Validate(nonce string) error { return fmt.Errorf("%w: %v", errInvalidNonce, err) //nolint:errorlint } - if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime { + if ts := time.UnixMilli(int64(binary.BigEndian.Uint64(b))); time.Since(ts) > nonceLifetime { //nolint:gosec return errInvalidNonce } diff --git a/internal/server/server.go b/internal/server/server.go index 253492e9..9cc19199 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -27,7 +27,8 @@ type Request struct { NonceHash *NonceHash // User Configuration - AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + Log logging.LeveledLogger Realm string ChannelBindTimeout time.Duration diff --git a/internal/server/turn.go b/internal/server/turn.go index 46e45ecb..5367b82d 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -115,6 +115,12 @@ func handleAllocateRequest(r Request, m *stun.Message) error { } } + // Parse realm and username (already checked in authenticateRequest) + realmAttr := &stun.Realm{} + _ = realmAttr.GetFrom(m) + usernameAttr := &stun.Username{} + _ = usernameAttr.GetFrom(m) + // 7. At any point, the server MAY choose to reject the request with a // 486 (Allocation Quota Reached) error if it feels the client is // trying to exceed some locally defined allocation quota. The @@ -131,7 +137,10 @@ func handleAllocateRequest(r Request, m *stun.Message) error { fiveTuple, r.Conn, requestedPort, - lifetimeDuration) + lifetimeDuration, + usernameAttr.String(), + realmAttr.String(), + ) if err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) } diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index e4a3b947..4c7703ca 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -97,7 +97,7 @@ func TestAllocationLifeTime(t *testing.T) { fiveTuple := &allocation.FiveTuple{SrcAddr: r.SrcAddr, DstAddr: r.Conn.LocalAddr(), Protocol: allocation.UDP} - _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour) + _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour, "", "") assert.NoError(t, err) assert.NotNil(t, r.AllocationManager.GetAllocation(fiveTuple)) diff --git a/internal/server/util.go b/internal/server/util.go index 7c01d329..f8642aea 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -10,6 +10,7 @@ import ( "time" "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/allocation" "github.com/pion/turn/v4/internal/proto" ) @@ -90,16 +91,47 @@ func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) ourKey, ok := r.AuthHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr) if !ok { + genAuthEvent(r, m, callingMethod, false) return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...) } if err := stun.MessageIntegrity(ourKey).Check(m); err != nil { + genAuthEvent(r, m, callingMethod, false) return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } + genAuthEvent(r, m, callingMethod, true) return stun.MessageIntegrity(ourKey), true, nil } +func genAuthEvent(r Request, m *stun.Message, callingMethod stun.Method, verdict bool) { + if r.AllocationManager.EventHandler == nil { + return + } + + realmAttr := &stun.Realm{} + if err := realmAttr.GetFrom(m); err != nil { + return + } + + usernameAttr := &stun.Username{} + if err := usernameAttr.GetFrom(m); err != nil { + return + } + + r.Log.Trace("Calling OnAuth event handler") + r.AllocationManager.EventHandler(allocation.EventHandlerArgs{ + Type: allocation.OnAuth, + SrcAddr: r.SrcAddr, + DstAddr: r.Conn.LocalAddr(), + Protocol: allocation.UDP, + Username: usernameAttr.String(), + Realm: realmAttr.String(), + Method: callingMethod.String(), + Verdict: verdict, + }) +} + func allocationLifeTime(m *stun.Message) time.Duration { lifetimeDuration := proto.DefaultLifetime diff --git a/lt_cred.go b/lt_cred.go index 42466c38..bd3197f1 100644 --- a/lt_cred.go +++ b/lt_cred.go @@ -79,7 +79,7 @@ func LongTermTURNRESTAuthHandler(sharedSecret string, l logging.LeveledLogger) A l = logging.NewDefaultLoggerFactory().NewLogger("turn") } return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - l.Tracef("Authentication username=%q realm=%q srcAddr=%v\n", username, realm, srcAddr) + l.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) timestamp := strings.Split(username, ":")[0] t, err := strconv.Atoi(timestamp) if err != nil { diff --git a/relay_address_generator_range.go b/relay_address_generator_range.go index d87a57f9..8ebf55d2 100644 --- a/relay_address_generator_range.go +++ b/relay_address_generator_range.go @@ -84,7 +84,7 @@ func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requ } for try := 0; try < r.MaxRetries; try++ { - port := r.MinPort + uint16(r.Rand.Intn(int((r.MaxPort+1)-r.MinPort))) + port := r.MinPort + uint16(r.Rand.Intn(int((r.MaxPort+1)-r.MinPort))) //nolint:gosec conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, port)) if err != nil { continue diff --git a/server.go b/server.go index 3b58938f..b96b5ada 100644 --- a/server.go +++ b/server.go @@ -27,6 +27,7 @@ type Server struct { realm string channelBindTimeout time.Duration nonceHash *server.NonceHash + eventHandlers EventHandlers packetConnConfigs []PacketConnConfig listenerConfigs []ListenerConfig @@ -66,6 +67,7 @@ func NewServer(config ServerConfig) (*Server, error) { listenerConfigs: config.ListenerConfigs, nonceHash: nonceHash, inboundMTU: mtu, + eventHandlers: config.EventHandlers, } if s.channelBindTimeout == 0 { @@ -191,6 +193,7 @@ func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, ha AllocatePacketConn: addrGenerator.AllocatePacketConn, AllocateConn: addrGenerator.AllocateConn, PermissionHandler: handler, + EventHandler: genericEventHandler(s.eventHandlers), LeveledLogger: s.log, }) if err != nil { @@ -226,6 +229,10 @@ func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manage ChannelBindTimeout: s.channelBindTimeout, NonceHash: s.nonceHash, }); err != nil { + if s.eventHandlers.OnAllocationError != nil { + s.log.Trace("Calling OnAllocationError event handler") + s.eventHandlers.OnAllocationError(addr, p.LocalAddr(), allocation.UDP.String(), err.Error()) + } s.log.Errorf("Failed to handle datagram: %v", err) } } diff --git a/server_config.go b/server_config.go index eab2988e..4dd55c6c 100644 --- a/server_config.go +++ b/server_config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pion/logging" + "github.com/pion/turn/v4/internal/allocation" ) // RelayAddressGenerator is used to generate a RelayAddress when creating an allocation. @@ -104,6 +105,82 @@ func GenerateAuthKey(username, realm, password string) []byte { return h.Sum(nil) } +// EventHandlers is a set of callbacks that the server will call at certain hook points during an +// allocation's lifecycle. All events are reported with the context that identifies the allocation +// triggering the event (source and destination address, protocol, username and realm used for +// authenticating the allocation). It is OK to handle only a subset of the callbacks. +type EventHandlers struct { + // OnAuth is called after an authentication request has been processed with the TURN method + // triggering the authentication request (either "Allocate", "Refresh" "CreatePermission", + // or "ChannelBind"), and the verdict is the authentication result. + OnAuth func(srcAddr, dstAddr net.Addr, protocol, username, realm string, method string, verdict bool) + // OnAllocationCreated is called after a new allocation has been made. The relayAddr + // argument specifies the relay address and requestedPort is the port requested by the + // client (if any). + OnAllocationCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, relayAddr net.Addr, requestedPort int) + // OnAllocationDeleted is called after an allocation has been removed. + OnAllocationDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string) + // OnAllocationError is called when the readloop hdndling an allocation exits with an + // error with an error message. + OnAllocationError func(srcAddr, dstAddr net.Addr, protocol, message string) + // OnPermissionCreated is called after a new permission has been made to an IP address. + OnPermissionCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, peer net.IP) + // OnPermissionDeleted is called after a permission for a given IP address has been + // removed. + OnPermissionDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string, peer net.IP) + // OnChannelCreated is called after a new channel has been made. The relay address, the + // peer address and the channel number can be used to uniquely identify the channel + // created. + OnChannelCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, relayAddr, peer net.Addr, channelNumber uint16) + // OnChannelDeleted is called after a channel has been removed from the server. The relay + // address, the peer address and the channel number can be used to uniquely identify the + // channel deleted. + OnChannelDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string, relayAddr, peer net.Addr, channelNumber uint16) +} + +func genericEventHandler(handlers EventHandlers) allocation.EventHandler { + return func(arg allocation.EventHandlerArgs) { + switch arg.Type { + case allocation.OnAuth: + if handlers.OnAuth != nil { + handlers.OnAuth(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.Method, arg.Verdict) + } + case allocation.OnAllocationCreated: + if handlers.OnAllocationCreated != nil { + handlers.OnAllocationCreated(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.RequestedPort) + } + case allocation.OnAllocationDeleted: + if handlers.OnAllocationDeleted != nil { + handlers.OnAllocationDeleted(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm) + } + case allocation.OnPermissionCreated: + if handlers.OnPermissionCreated != nil { + handlers.OnPermissionCreated(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.PeerIP) + } + case allocation.OnPermissionDeleted: + if handlers.OnPermissionDeleted != nil { + handlers.OnPermissionDeleted(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.PeerIP) + } + case allocation.OnChannelCreated: + if handlers.OnChannelCreated != nil { + handlers.OnChannelCreated(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.PeerAddr, arg.ChannelNumber) + } + case allocation.OnChannelDeleted: + if handlers.OnChannelDeleted != nil { + handlers.OnChannelDeleted(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.PeerAddr, arg.ChannelNumber) + } + default: + } + } +} + // ServerConfig configures the Pion TURN Server type ServerConfig struct { // PacketConnConfigs and ListenerConfigs are a list of all the turn listeners @@ -120,6 +197,9 @@ type ServerConfig struct { // AuthHandler is a callback used to handle incoming auth requests, allowing users to customize Pion TURN with custom behavior AuthHandler AuthHandler + // EventHandlers is a set of callbacks for tracking allocation lifecycle. + EventHandlers EventHandlers + // ChannelBindTimeout sets the lifetime of channel binding. Defaults to 10 minutes. ChannelBindTimeout time.Duration diff --git a/server_test.go b/server_test.go index 44020db6..ab53f69b 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,7 @@ package turn import ( "fmt" "net" + "sync/atomic" "syscall" "testing" "time" @@ -21,6 +22,13 @@ import ( "github.com/stretchr/testify/assert" ) +const ( + timeout = 200 * time.Millisecond + interval = 50 * time.Millisecond + stunAddr = "1.2.3.4:3478" + turnAddr = "1.2.3.4:3478" +) + func TestServer(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -378,7 +386,14 @@ func (v *VNet) Close() error { } func buildVNet() (*VNet, error) { + return buildVNetWithServerEventHandlers(nil) +} + +func buildVNetWithServerEventHandlers(handlers *EventHandlers) (*VNet, error) { loggerFactory := logging.NewDefaultLoggerFactory() + if handlers == nil { + handlers = &EventHandlers{} + } // WAN wan, err := vnet.NewRouter(&vnet.RouterConfig{ @@ -447,7 +462,7 @@ func buildVNet() (*VNet, error) { // Start server... credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")} - udpListener, err := net0.ListenPacket("udp4", "0.0.0.0:3478") + udpListener, err := net0.ListenPacket("udp4", "1.2.3.4:3478") if err != nil { return nil, err } @@ -459,7 +474,8 @@ func buildVNet() (*VNet, error) { } return nil, false }, - Realm: "pion.ly", + Realm: "pion.ly", + EventHandlers: *handlers, PacketConnConfigs: []PacketConnConfig{ { PacketConn: udpListener, @@ -498,6 +514,15 @@ func buildVNet() (*VNet, error) { }, nil } +func expectEvent(ch chan allocation.EventHandlerArgs) (allocation.EventHandlerArgs, bool) { + select { + case res := <-ch: + return res, true + case <-time.After(timeout): + return allocation.EventHandlerArgs{}, false + } +} + func TestServerVNet(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -521,8 +546,6 @@ func TestServerVNet(t *testing.T) { assert.NoError(t, lconn.Close()) }() - stunAddr := "1.2.3.4:3478" - log.Debug("creating a client.") client, err := NewClient(&ClientConfig{ STUNServerAddr: stunAddr, @@ -544,6 +567,433 @@ func TestServerVNet(t *testing.T) { // to the LAN router. assert.True(t, udpAddr.IP.Equal(net.IPv4(5, 6, 7, 8)), "should match") }) + + t.Run("AllocationLifecycle", func(t *testing.T) { + v, err := buildVNet() + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + // Inject an fake event handler so that we can track the succession of callbacks + events := make(chan allocation.EventHandlerArgs, 5) + defer close(events) + assert.Len(t, v.server.allocationManagers, 1) + v.server.allocationManagers[0].EventHandler = func(arg allocation.EventHandlerArgs) { + log.Info(fmt.Sprintf("%#v", arg)) + events <- arg + } + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + relayConn, err := client.Allocate() + assert.NoError(t, err, "should succeed") + + event, ok := expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok := event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "Allocate", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAllocationCreated, event.Type, "should receive an OnAllocationCreated event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, 0, event.RequestedPort) + + relayNetAddr := relayConn.LocalAddr() + log.Debugf("relay-address: %s", relayNetAddr.String()) + relayAddr, ok := relayNetAddr.(*net.UDPAddr) + assert.True(t, ok) + udpAddr, ok = event.RelayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddr.IP.Equal(udpAddr.IP)) + assert.Equal(t, relayAddr.Port, udpAddr.Port) + // The transport relay address should have IP address that was assigned to the server. + assert.True(t, udpAddr.IP.Equal(net.IPv4(1, 2, 3, 4)), "should match") + + log.Debug("Sending test packet") + peerAddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.5"), Port: 80} + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "CreatePermission", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnPermissionCreated, event.Type, "should receive an OnPermissionCreated event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.True(t, net.ParseIP("1.2.3.5").Equal(event.PeerIP)) + + log.Debug("Forcing the creation of a channel") + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "ChannelBind", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnChannelCreated, event.Type, "should receive an OnChannelCreated event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + udpAddr, ok = event.RelayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddr.IP.Equal(udpAddr.IP)) + assert.Equal(t, relayAddr.Port, udpAddr.Port) + + // obtain the channel id + a := v.server.allocationManagers[0].GetAllocation(&allocation.FiveTuple{ + Protocol: allocation.UDP, + SrcAddr: event.SrcAddr, + DstAddr: event.DstAddr, + }) + assert.NotNil(t, a) + channelBind := a.GetChannelByAddr(peerAddr) + assert.NotNil(t, channelBind) + assert.Equal(t, channelBind.Number, proto.ChannelNumber(event.ChannelNumber)) + + log.Debug("Closing relay connection") + assert.NoError(t, relayConn.Close(), "relay conn close should succeed") + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "Refresh", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnPermissionDeleted, event.Type, "should receive an OnPermissionDeleted event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.True(t, net.ParseIP("1.2.3.5").Equal(event.PeerIP)) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnChannelDeleted, event.Type, "should receive an OnChannelDeleted event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + udpAddr, ok = event.RelayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddr.IP.Equal(udpAddr.IP)) + assert.Equal(t, relayAddr.Port, udpAddr.Port) + assert.Equal(t, channelBind.Number, proto.ChannelNumber(event.ChannelNumber)) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAllocationDeleted, event.Type, "should receive an OnAllocationDeleted event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + }) + + checkAllocation := func(srcAddr, dstAddr net.Addr, protocol, username, realm string) { + udpAddr, ok := srcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = dstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP.String(), protocol) + assert.Equal(t, "user", username) + assert.Equal(t, "pion.ly", realm) + } + authEventHandler := func(expectedVerdict bool) (*EventHandlers, *atomic.Int32) { + counter := &atomic.Int32{} + return &EventHandlers{ + OnAuth: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, method string, verdict bool) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + assert.True(t, method == "Allocate" || method == "Refresh") // close calls refresh with 0 lifetime + assert.Equal(t, expectedVerdict, verdict) + counter.Add(1) + }, + }, counter + } + + t.Run("AuthEventHandlerSuccess", func(t *testing.T) { + authCallback, counter := authEventHandler(true) + v, err := buildVNetWithServerEventHandlers(authCallback) + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + relayConn, err := client.Allocate() + assert.NoError(t, err, "should succeed") + + log.Debug("Closing relay connection") + assert.NoError(t, relayConn.Close(), "relay conn close should succeed") + + assert.Eventually(t, func() bool { return counter.Load() == 2 }, timeout, interval) + }) + + t.Run("AuthEventHandlerFailure", func(t *testing.T) { + authCallback, counter := authEventHandler(false) + v, err := buildVNetWithServerEventHandlers(authCallback) + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "wrong-pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + _, err = client.Allocate() + assert.Error(t, err, "should not succeed") + + assert.Eventually(t, func() bool { return counter.Load() == 1 }, timeout, interval) + }) + + t.Run("AllocationEventHandlers", func(t *testing.T) { + peerAddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.5"), Port: 80} + relayAddrIP := net.ParseIP("1.2.3.4") + allocCreated, allocDeleted := &atomic.Int32{}, &atomic.Int32{} + permissionCreated, permissionDeleted := &atomic.Int32{}, &atomic.Int32{} + channelCreated, channelDeleted := &atomic.Int32{}, &atomic.Int32{} + allocCallback := &EventHandlers{ + OnAllocationCreated: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, relayAddr net.Addr, requestedPort int) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + assert.Equal(t, 0, requestedPort) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + allocCreated.Add(1) + }, + OnAllocationDeleted: func(srcAddr, dstAddr net.Addr, protocol, username, realm string) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + allocDeleted.Add(1) + }, + OnPermissionCreated: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, peer net.IP) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + assert.True(t, net.ParseIP("1.2.3.5").Equal(peer)) + permissionCreated.Add(1) + }, + OnPermissionDeleted: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, peer net.IP) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + assert.True(t, net.ParseIP("1.2.3.5").Equal(peer)) + permissionDeleted.Add(1) + }, + OnChannelCreated: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, relayAddr, peer net.Addr, channelNumber uint16) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + addr, ok := peer.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, addr.IP.Equal(peerAddr.IP)) + assert.Equal(t, peerAddr.Port, addr.Port) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + assert.NotZero(t, channelNumber) + channelCreated.Add(1) + }, + OnChannelDeleted: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, relayAddr, peer net.Addr, channelNumber uint16) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + addr, ok := peer.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, addr.IP.Equal(peerAddr.IP)) + assert.Equal(t, peerAddr.Port, addr.Port) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + assert.NotZero(t, channelNumber) + channelDeleted.Add(1) + }, + } + + v, err := buildVNetWithServerEventHandlers(allocCallback) + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + relayConn, err := client.Allocate() + assert.NoError(t, err, "should succeed") + + assert.Eventually(t, func() bool { return allocCreated.Load() == 1 }, timeout, interval) + + log.Debug("Sending test packet") + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + assert.Eventually(t, func() bool { return permissionCreated.Load() == 1 }, timeout, interval) + + log.Debug("Forcing the creation of a channel") + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + assert.Eventually(t, func() bool { return channelCreated.Load() == 1 }, timeout, interval) + + log.Debug("Closing relay connection") + assert.NoError(t, relayConn.Close(), "relay conn close should succeed") + + assert.Eventually(t, func() bool { return permissionDeleted.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return allocCreated.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return allocDeleted.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return permissionCreated.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return permissionDeleted.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return channelCreated.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return channelDeleted.Load() == 1 }, timeout, interval) + }) } func TestConsumeSingleTURNFrame(t *testing.T) {