Skip to content

Commit

Permalink
swarm: fix DialPeer behaviour for transient connections (#2547)
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 18, 2023
1 parent 91c432f commit cd9e7cb
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 29 deletions.
8 changes: 4 additions & 4 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ loop:
// Enqueue the peer's addresses relevant to this request in dq and
// track dials to the addresses relevant to this request.

c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
if c != nil || err != nil {
req.resch <- dialResponse{conn: c, err: err}
c := w.s.bestAcceptableConnToPeer(req.ctx, w.peer)
if c != nil {
req.resch <- dialResponse{conn: c}
continue loop
}

Expand Down Expand Up @@ -396,7 +396,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
// all addrs have erred, dispatch dial error
// but first do a last one check in case an acceptable connection has landed from
// a simultaneous dial that started later and added new acceptable addrs
c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
c := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
if c != nil {
pr.req.resch <- dialResponse{conn: c}
} else {
Expand Down
28 changes: 9 additions & 19 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,8 @@ func (s *Swarm) StreamHandler() network.StreamHandler {

// NewStream creates a new stream on any available connection to peer, dialing
// if necessary.
// Use network.WithUseTransient to open a stream over a transient(relayed)
// connection.
func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error) {
log.Debugf("[%s] opening stream to peer [%s]", s.local, p)

Expand All @@ -447,10 +449,7 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error
dials := 0
for {
// will prefer direct connections over relayed connections for opening streams
c, err := s.bestAcceptableConnToPeer(ctx, p)
if err != nil {
return nil, err
}
c := s.bestAcceptableConnToPeer(ctx, p)

if c == nil {
if nodial, _ := network.GetNoDial(ctx); nodial {
Expand Down Expand Up @@ -548,26 +547,17 @@ func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {
return best
}

// - Returns the best "acceptable" connection, if available.
// - Returns nothing if no such connection exists, but if we should try dialing anyways.
// - Returns an error if no such connection exists, but we should not try dialing.
func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) (*Conn, error) {
// bestAcceptableConnToPeer returns the best acceptable connection, considering the passed in ctx.
// If network.WithForceDirectDial is used, it only returns a direct connections, ignoring
// any transient (relayed) connections to the peer.
func (s *Swarm) bestAcceptableConnToPeer(ctx context.Context, p peer.ID) *Conn {
conn := s.bestConnToPeer(p)
if conn == nil {
return nil, nil
}

forceDirect, _ := network.GetForceDirectDial(ctx)
if forceDirect && !isDirectConn(conn) {
return nil, nil
}

useTransient, _ := network.GetUseTransient(ctx)
if useTransient || !conn.Stat().Transient {
return conn, nil
return nil
}

return nil, network.ErrTransientConn
return conn
}

func isDirectConn(c *Conn) bool {
Expand Down
12 changes: 6 additions & 6 deletions p2p/net/swarm/swarm_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ func (db *DialBackoff) cleanup() {
}
}

// DialPeer connects to a peer.
// DialPeer connects to a peer. Use network.WithForceDirectDial to force a
// direct connection.
//
// The idea is that the client of Swarm does not need to know what network
// the connection will happen over. Swarm can use whichever it choses.
Expand Down Expand Up @@ -246,11 +247,10 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) {
return nil, ErrDialToSelf
}

// check if we already have an open (usable) connection first, or can't have a usable
// connection.
conn, err := s.bestAcceptableConnToPeer(ctx, p)
if conn != nil || err != nil {
return conn, err
// check if we already have an open (usable) connection.
conn := s.bestAcceptableConnToPeer(ctx, p)
if conn != nil {
return conn, nil
}

if s.gater != nil && !s.gater.InterceptPeerDial(p) {
Expand Down
70 changes: 70 additions & 0 deletions p2p/test/swarm/swarm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package swarm_test

import (
"context"
"testing"

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)

func TestDialPeerTransientConnection(t *testing.T) {
h1, err := libp2p.New(
libp2p.NoListenAddrs,
libp2p.EnableRelay(),
)
require.NoError(t, err)

h2, err := libp2p.New(
libp2p.NoListenAddrs,
libp2p.EnableRelay(),
)
require.NoError(t, err)

relay1, err := libp2p.New()
require.NoError(t, err)

_, err = relay.New(relay1)
require.NoError(t, err)

relay1info := peer.AddrInfo{
ID: relay1.ID(),
Addrs: relay1.Addrs(),
}
err = h1.Connect(context.Background(), relay1info)
require.NoError(t, err)

err = h2.Connect(context.Background(), relay1info)
require.NoError(t, err)

_, err = client.Reserve(context.Background(), h2, relay1info)
require.NoError(t, err)

relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())

h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL)

// swarm.DialPeer should connect over transient connections
conn1, err := h1.Network().DialPeer(context.Background(), h2.ID())
require.NoError(t, err)
require.NotNil(t, conn1)

// Test that repeated calls return the same connection.
conn2, err := h1.Network().DialPeer(context.Background(), h2.ID())
require.NoError(t, err)
require.NotNil(t, conn2)

require.Equal(t, conn1, conn2)

// swarm.DialPeer should fail if forceDirect is used
ctx := network.WithForceDirectDial(context.Background(), "test")
conn, err := h1.Network().DialPeer(ctx, h2.ID())
require.Error(t, err)
require.Nil(t, conn)
}

0 comments on commit cd9e7cb

Please sign in to comment.