diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index d26b9e781e..a51ed99018 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/transport" + "golang.org/x/exp/slices" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -169,7 +170,8 @@ type Swarm struct { notifs struct { sync.RWMutex - m map[network.Notifiee]struct{} + m map[network.Notifiee]struct{} + directConn map[peer.ID][]chan struct{} } transports struct { @@ -230,7 +232,9 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.conns.m = make(map[peer.ID][]*Conn) s.listeners.m = make(map[transport.Listener]struct{}) s.transports.m = make(map[int]transport.Transport) + s.notifs.m = make(map[network.Notifiee]struct{}) + s.notifs.directConn = make(map[peer.ID][]chan struct{}) for _, opt := range opts { if err := opt(s); err != nil { @@ -390,6 +394,16 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, c.notifyLk.Lock() s.conns.Unlock() + // Notify goroutines waiting for a direct connection + s.notifs.Lock() + if !c.Stat().Transient { + for _, ch := range s.notifs.directConn[p] { + close(ch) + } + delete(s.notifs.directConn, p) + } + s.notifs.Unlock() + // Emit event after releasing `s.conns` lock so that a consumer can still // use swarm methods that need the `s.conns` lock. if isFirstConnection { @@ -441,40 +455,83 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error // connection but don't notice it until we actually try to open a // stream. // - // Note: We only dial once. - // // TODO: Try all connections even if we get an error opening a stream on // a non-closed connection. - dials := 0 - for { - // will prefer direct connections over relayed connections for opening streams - c := s.bestAcceptableConnToPeer(ctx, p) - + dialed := false + for i := 0; i < 1; i++ { + c := s.bestConnToPeer(p) if c == nil { - if nodial, _ := network.GetNoDial(ctx); nodial { + if nodial, _ := network.GetNoDial(ctx); !nodial { + if dialed { + return nil, errors.New("max dial attempts exceeded") + } + dialed = true + var err error + c, err = s.dialPeer(ctx, p) + if err != nil { + return nil, err + } + } else { return nil, network.ErrNoConn } + } - if dials >= DialAttempts { - return nil, errors.New("max dial attempts exceeded") - } - dials++ - + useTransient, _ := network.GetUseTransient(ctx) + if !useTransient && c.Stat().Transient { var err error - c, err = s.dialPeer(ctx, p) + c, err = s.waitForTransientUpgrade(ctx, p) if err != nil { return nil, err } } - s, err := c.NewStream(ctx) + str, err := c.NewStream(ctx) if err != nil { if c.conn.IsClosed() { continue } return nil, err } - return s, nil + return str, nil + } + return nil, network.ErrNoConn +} + +// waitForTransientUpgrade waits for a direct connection established through DCUtR or connection reversal. +func (s *Swarm) waitForTransientUpgrade(ctx context.Context, p peer.ID) (*Conn, error) { + s.notifs.Lock() + c := s.bestConnToPeer(p) + if c == nil { + s.notifs.Unlock() + return nil, network.ErrNoConn + } else if !c.Stat().Transient { + s.notifs.Unlock() + return c, nil + } + // Wait for transient connection to upgrade to a direct connection either via connection reversal + // or hole punching. + ch := make(chan struct{}) + s.notifs.directConn[p] = append(s.notifs.directConn[p], ch) + s.notifs.Unlock() + + // Wait for notification. + // There's no point waiting for more than a minute here. + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + select { + case <-ctx.Done(): + s.notifs.Lock() + // Remove ourselves from the notification list + slices.DeleteFunc(s.notifs.directConn[p], func(c chan struct{}) bool { return c == ch }) + if len(s.notifs.directConn[p]) == 0 { + delete(s.notifs.directConn, p) + } + s.notifs.Unlock() + return nil, ctx.Err() + case <-ch: + // We do not need to remove ourselves from the list here as the notifier would have cleared the map + // entry + return s.bestConnToPeer(p), nil } } diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go index 6b010ed2aa..f146872261 100644 --- a/p2p/test/basichost/basic_host_test.go +++ b/p2p/test/basichost/basic_host_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/network" @@ -62,10 +63,12 @@ func TestNoStreamOverTransientConnection(t *testing.T) { err = h1.Connect(context.Background(), h2Info) require.NoError(t, err) - ctx := network.WithNoDial(context.Background(), "test") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx = network.WithNoDial(ctx, "test") _, err = h1.NewStream(ctx, h2.ID(), "/testprotocol") - require.ErrorIs(t, err, network.ErrTransientConn) + require.Error(t, err) _, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol") require.NoError(t, err)