Skip to content

Commit

Permalink
swarm: wait for transient connections to upgrade for NewStream
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 31, 2023
1 parent a29a92e commit 6b1cf9e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 19 deletions.
91 changes: 74 additions & 17 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/transport"
"golang.org/x/exp/slices"

logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -169,7 +170,8 @@ type Swarm struct {

notifs struct {
sync.RWMutex
m map[network.Notifiee]struct{}
m map[network.Notifiee]struct{}
directConn map[peer.ID][]chan struct{}
}

transports struct {
Expand Down Expand Up @@ -230,7 +232,9 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
s.conns.m = make(map[peer.ID][]*Conn)
s.listeners.m = make(map[transport.Listener]struct{})
s.transports.m = make(map[int]transport.Transport)

s.notifs.m = make(map[network.Notifiee]struct{})
s.notifs.directConn = make(map[peer.ID][]chan struct{})

for _, opt := range opts {
if err := opt(s); err != nil {
Expand Down Expand Up @@ -390,6 +394,16 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
c.notifyLk.Lock()
s.conns.Unlock()

// Notify goroutines waiting for a direct connection
s.notifs.Lock()
if !c.Stat().Transient {
for _, ch := range s.notifs.directConn[p] {
close(ch)
}
delete(s.notifs.directConn, p)
}
s.notifs.Unlock()

// Emit event after releasing `s.conns` lock so that a consumer can still
// use swarm methods that need the `s.conns` lock.
if isFirstConnection {
Expand Down Expand Up @@ -441,40 +455,83 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error
// connection but don't notice it until we actually try to open a
// stream.
//
// Note: We only dial once.
//
// TODO: Try all connections even if we get an error opening a stream on
// a non-closed connection.
dials := 0
for {
// will prefer direct connections over relayed connections for opening streams
c := s.bestAcceptableConnToPeer(ctx, p)

dialed := false
for i := 0; i < 1; i++ {
c := s.bestConnToPeer(p)
if c == nil {
if nodial, _ := network.GetNoDial(ctx); nodial {
if nodial, _ := network.GetNoDial(ctx); !nodial {
if dialed {
return nil, errors.New("max dial attempts exceeded")
}
dialed = true
var err error
c, err = s.dialPeer(ctx, p)
if err != nil {
return nil, err
}
} else {
return nil, network.ErrNoConn
}
}

if dials >= DialAttempts {
return nil, errors.New("max dial attempts exceeded")
}
dials++

useTransient, _ := network.GetUseTransient(ctx)
if !useTransient && c.Stat().Transient {
var err error
c, err = s.dialPeer(ctx, p)
c, err = s.waitForTransientUpgrade(ctx, p)
if err != nil {
return nil, err
}
}

s, err := c.NewStream(ctx)
str, err := c.NewStream(ctx)
if err != nil {
if c.conn.IsClosed() {
continue
}
return nil, err
}
return s, nil
return str, nil
}
return nil, network.ErrNoConn
}

// waitForTransientUpgrade waits for a direct (non-transient) connection to
// peer p, such as one established through DCUtR or connection reversal.
//
// If the best current connection to p is already non-transient it is returned
// immediately. If there is no connection at all, network.ErrNoConn is
// returned. Otherwise the caller is registered in s.notifs.directConn and
// blocks until either the registered channel is closed or ctx (capped at one
// minute) is done, in which case ctx.Err() is returned.
//
// On success the returned connection is never nil and never transient.
func (s *Swarm) waitForTransientUpgrade(ctx context.Context, p peer.ID) (*Conn, error) {
	s.notifs.Lock()
	c := s.bestConnToPeer(p)
	if c == nil {
		s.notifs.Unlock()
		return nil, network.ErrNoConn
	}
	if !c.Stat().Transient {
		s.notifs.Unlock()
		return c, nil
	}

	// Register for a notification while still holding the lock, so that a
	// direct connection announced after the check above cannot be missed.
	ch := make(chan struct{})
	s.notifs.directConn[p] = append(s.notifs.directConn[p], ch)
	s.notifs.Unlock()

	// There's no point waiting for more than a minute here.
	ctx, cancel := context.WithTimeout(ctx, time.Minute)
	defer cancel()

	select {
	case <-ctx.Done():
		// Remove ourselves from the notification list. DeleteFunc returns the
		// shortened slice; it must be stored back, otherwise the map entry
		// keeps its old length and is never cleaned up.
		s.notifs.Lock()
		defer s.notifs.Unlock()

		waiters := slices.DeleteFunc(s.notifs.directConn[p], func(c chan struct{}) bool { return c == ch })
		if len(waiters) == 0 {
			delete(s.notifs.directConn, p)
		} else {
			s.notifs.directConn[p] = waiters
		}
		return nil, ctx.Err()
	case <-ch:
		// We do not need to remove ourselves from the list here as the
		// notifier clears the map entry after closing the channels.
		c := s.bestConnToPeer(p)
		if c == nil {
			// The connection may have been closed between the notification
			// and this lookup; never hand a nil *Conn back with a nil error.
			return nil, network.ErrNoConn
		}
		if c.Stat().Transient {
			// Still only a transient connection available; the caller asked
			// us specifically to avoid those.
			return nil, network.ErrTransientConn
		}
		return c, nil
	}
}

Expand Down
7 changes: 5 additions & 2 deletions p2p/test/basichost/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"testing"
"time"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network"
Expand Down Expand Up @@ -62,10 +63,12 @@ func TestNoStreamOverTransientConnection(t *testing.T) {
err = h1.Connect(context.Background(), h2Info)
require.NoError(t, err)

ctx := network.WithNoDial(context.Background(), "test")
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
ctx = network.WithNoDial(ctx, "test")
_, err = h1.NewStream(ctx, h2.ID(), "/testprotocol")

require.ErrorIs(t, err, network.ErrTransientConn)
require.Error(t, err)

_, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol")
require.NoError(t, err)
Expand Down

0 comments on commit 6b1cf9e

Please sign in to comment.