diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 30dec7b3ce..a51ed99018 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -170,8 +170,8 @@ type Swarm struct { notifs struct { sync.RWMutex - m map[network.Notifiee]struct{} - waiters map[peer.ID][]chan struct{} + m map[network.Notifiee]struct{} + directConn map[peer.ID][]chan struct{} } transports struct { @@ -232,8 +232,9 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.conns.m = make(map[peer.ID][]*Conn) s.listeners.m = make(map[transport.Listener]struct{}) s.transports.m = make(map[int]transport.Transport) + s.notifs.m = make(map[network.Notifiee]struct{}) - s.notifs.waiters = make(map[peer.ID][]chan struct{}) + s.notifs.directConn = make(map[peer.ID][]chan struct{}) for _, opt := range opts { if err := opt(s); err != nil { @@ -394,12 +395,14 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, s.conns.Unlock() // Notify goroutines waiting for a direct connection + s.notifs.Lock() if !c.Stat().Transient { - for _, ch := range s.notifs.waiters[p] { + for _, ch := range s.notifs.directConn[p] { close(ch) } - delete(s.notifs.waiters, p) + delete(s.notifs.directConn, p) } + s.notifs.Unlock() // Emit event after releasing `s.conns` lock so that a consumer can still // use swarm methods that need the `s.conns` lock. @@ -471,13 +474,12 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error } else { return nil, network.ErrNoConn } - } useTransient, _ := network.GetUseTransient(ctx) if !useTransient && c.Stat().Transient { var err error - c, err = s.waitAndGetDirectConn(ctx, p) + c, err = s.waitForTransientUpgrade(ctx, p) if err != nil { return nil, err } @@ -495,30 +497,42 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error return nil, network.ErrNoConn } -// waitAndGetDirectConn waits for a direct connection till the context finishes or we get -// a direct connection through DCUtR or connection reversal. -func (s *Swarm) waitAndGetDirectConn(ctx context.Context, p peer.ID) (*Conn, error) { +// waitForTransientUpgrade waits for a direct connection established through DCUtR or connection reversal. +func (s *Swarm) waitForTransientUpgrade(ctx context.Context, p peer.ID) (*Conn, error) { s.notifs.Lock() c := s.bestConnToPeer(p) - if c == nil || c.Stat().Transient { - ch := make(chan struct{}) - s.notifs.waiters[p] = append(s.notifs.waiters[p], ch) + if c == nil { s.notifs.Unlock() - select { - case <-ctx.Done(): - s.notifs.Lock() - slices.DeleteFunc(s.notifs.waiters[p], func(c chan struct{}) bool { return c == ch }) - if len(s.notifs.waiters[p]) == 0 { - delete(s.notifs.waiters, p) - } - s.notifs.Unlock() - return nil, ctx.Err() - case <-ch: - return s.bestConnToPeer(p), nil - } + return nil, network.ErrNoConn + } else if !c.Stat().Transient { + s.notifs.Unlock() + return c, nil } + // Wait for transient connection to upgrade to a direct connection either via connection reversal + // or hole punching. + ch := make(chan struct{}) + s.notifs.directConn[p] = append(s.notifs.directConn[p], ch) s.notifs.Unlock() - return c, nil + + // Wait for notification. + // There's no point waiting for more than a minute here. + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + select { + case <-ctx.Done(): + s.notifs.Lock() + // Remove ourselves from the notification list + slices.DeleteFunc(s.notifs.directConn[p], func(c chan struct{}) bool { return c == ch }) + if len(s.notifs.directConn[p]) == 0 { + delete(s.notifs.directConn, p) + } + s.notifs.Unlock() + return nil, ctx.Err() + case <-ch: + // We do not need to remove ourselves from the list here as the notifier would have cleared the map + // entry + return s.bestConnToPeer(p), nil + } } // ConnsToPeer returns all the live connections to peer.