From db8eba8bbaf13126b549351bac5d9b09b8bf1d9d Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 1 Sep 2023 14:51:54 +0530 Subject: [PATCH 1/3] swarm: wait for transient connections to upgrade for NewStream --- p2p/net/swarm/swarm.go | 111 +++++++++++++++++++++----- p2p/test/basichost/basic_host_test.go | 77 +++++++++++++++++- p2p/test/swarm/swarm_test.go | 63 +++++++++++++++ 3 files changed, 231 insertions(+), 20 deletions(-) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 57361e8d59..b40c49d36e 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/transport" + "golang.org/x/exp/slices" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -172,6 +173,11 @@ type Swarm struct { m map[network.Notifiee]struct{} } + directConnNotifs struct { + sync.Mutex + m map[peer.ID][]chan struct{} + } + transports struct { sync.RWMutex m map[int]transport.Transport @@ -231,6 +237,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.listeners.m = make(map[transport.Listener]struct{}) s.transports.m = make(map[int]transport.Transport) s.notifs.m = make(map[network.Notifiee]struct{}) + s.directConnNotifs.m = make(map[peer.ID][]chan struct{}) for _, opt := range opts { if err := opt(s); err != nil { @@ -390,6 +397,19 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, c.notifyLk.Lock() s.conns.Unlock() + // Notify goroutines waiting for a direct connection + + // Go routines interested in waiting for direct connection first acquire this lock and then + // acquire conns.RLock. Do not acquire this lock before conns.Unlock to prevent deadlock. + s.directConnNotifs.Lock() + if !c.Stat().Transient { + for _, ch := range s.directConnNotifs.m[p] { + close(ch) + } + delete(s.directConnNotifs.m, p) + } + s.directConnNotifs.Unlock() + // Emit event after releasing `s.conns` lock so that a consumer can still // use swarm methods that need the `s.conns` lock. if isFirstConnection { @@ -436,46 +456,101 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error // Algorithm: // 1. Find the best connection, otherwise, dial. - // 2. Try opening a stream. - // 3. If the underlying connection is, in fact, closed, close the outer + // 2. If the best connection is transient, wait for a direct conn via conn + // reversal or hole punching. + // 3. Try opening a stream. + // 4. If the underlying connection is, in fact, closed, close the outer // connection and try again. We do this in case we have a closed // connection but don't notice it until we actually try to open a // stream. // - // Note: We only dial once. - // // TODO: Try all connections even if we get an error opening a stream on // a non-closed connection. - dials := 0 - for { - // will prefer direct connections over relayed connections for opening streams - c := s.bestAcceptableConnToPeer(ctx, p) - + dialed := false + for i := 0; i < 1; i++ { + c := s.bestConnToPeer(p) if c == nil { - if nodial, _ := network.GetNoDial(ctx); nodial { + if nodial, _ := network.GetNoDial(ctx); !nodial { + if dialed { + return nil, errors.New("max dial attempts exceeded") + } + dialed = true + var err error + c, err = s.dialPeer(ctx, p) + if err != nil { + return nil, err + } + } else { return nil, network.ErrNoConn } + } - if dials >= DialAttempts { - return nil, errors.New("max dial attempts exceeded") - } - dials++ - + useTransient, _ := network.GetUseTransient(ctx) + if !useTransient && c.Stat().Transient { var err error - c, err = s.dialPeer(ctx, p) + c, err = s.waitForDirectConn(ctx, p) if err != nil { return nil, err } } - s, err := c.NewStream(ctx) + str, err := c.NewStream(ctx) if err != nil { if c.conn.IsClosed() { continue } return nil, err } - return s, nil + return str, nil + } + return nil, network.ErrNoConn +} + +// waitForDirectConn waits for a direct connection established through hole punching or connection reversal. +func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) { + s.directConnNotifs.Lock() + c := s.bestConnToPeer(p) + if c == nil { + s.directConnNotifs.Unlock() + return nil, network.ErrNoConn + } else if !c.Stat().Transient { + s.directConnNotifs.Unlock() + return c, nil + } + + // Wait for transient connection to upgrade to a direct connection either by + // connection reversal or hole punching. + ch := make(chan struct{}) + s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], ch) + s.directConnNotifs.Unlock() + + // Wait for notification. + // There's no point waiting for more than a minute here. + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + select { + case <-ctx.Done(): + // Remove ourselves from the notification list + s.directConnNotifs.Lock() + s.directConnNotifs.m[p] = slices.DeleteFunc( + s.directConnNotifs.m[p], + func(c chan struct{}) bool { return c == ch }, + ) + if len(s.directConnNotifs.m[p]) == 0 { + delete(s.directConnNotifs.m, p) + } + s.directConnNotifs.Unlock() + return nil, ctx.Err() + case <-ch: + // We do not need to remove ourselves from the list here as the notifier + // clears the map + c := s.bestConnToPeer(p) + if c == nil { + return nil, network.ErrNoConn + } else if c.Stat().Transient { + return nil, network.ErrTransientConn + } + return c, nil } } diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go index 6b010ed2aa..025f906464 100644 --- a/p2p/test/basichost/basic_host_test.go +++ b/p2p/test/basichost/basic_host_test.go @@ -4,13 +4,16 @@ import ( "context" "fmt" "testing" + "time" "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -62,11 +65,81 @@ func TestNoStreamOverTransientConnection(t *testing.T) { err = h1.Connect(context.Background(), h2Info) require.NoError(t, err) - ctx := network.WithNoDial(context.Background(), "test") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx = network.WithNoDial(ctx, "test") _, err = h1.NewStream(ctx, h2.ID(), "/testprotocol") - require.ErrorIs(t, err, network.ErrTransientConn) + require.Error(t, err) _, err = h1.NewStream(network.WithUseTransient(context.Background(), "test"), h2.ID(), "/testprotocol") require.NoError(t, err) } + +func TestNewStreamTransientConnection(t *testing.T) { + h1, err := libp2p.New( + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relay.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + err = h1.Connect(context.Background(), relay1info) + require.NoError(t, err) + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + h2.SetStreamHandler("/testprotocol", func(s network.Stream) { + fmt.Println("testprotocol") + + // End the example + s.Close() + }) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) + + h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL) + + // NewStream should block transient connections till we have a direct connection + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + s, err := h1.NewStream(ctx, h2.ID(), "/testprotocol") + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, s) + + // NewStream should return a stream if a direct connection is established + // while waiting + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + time.AfterFunc(time.Second, func() { + // connect h2 to h1 simulating connection reversal + h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ctx = network.WithForceDirectDial(ctx, "test") + err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()}) + assert.NoError(t, err) + }) + s, err = h1.NewStream(ctx, h2.ID(), "/testprotocol") + require.NoError(t, err) + require.NotNil(t, s) +} diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index 2ddadb3576..e5bcdbffbe 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -3,6 +3,7 @@ package swarm_test import ( "context" "testing" + "time" "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/network" @@ -11,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -68,3 +70,64 @@ func TestDialPeerTransientConnection(t *testing.T) { require.Error(t, err) require.Nil(t, conn) } + +func TestNewStreamTransientConnection(t *testing.T) { + h1, err := libp2p.New( + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + h2, err := libp2p.New( + libp2p.NoListenAddrs, + libp2p.EnableRelay(), + ) + require.NoError(t, err) + + relay1, err := libp2p.New() + require.NoError(t, err) + + _, err = relay.New(relay1) + require.NoError(t, err) + + relay1info := peer.AddrInfo{ + ID: relay1.ID(), + Addrs: relay1.Addrs(), + } + err = h1.Connect(context.Background(), relay1info) + require.NoError(t, err) + + err = h2.Connect(context.Background(), relay1info) + require.NoError(t, err) + + _, err = client.Reserve(context.Background(), h2, relay1info) + require.NoError(t, err) + + relayaddr := ma.StringCast("/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) + + h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL) + + // NewStream should block transient connections till we have a direct connection + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + s, err := h1.Network().NewStream(ctx, h2.ID()) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, s) + + // NewStream should return a stream if a direct connection is established + // while waiting + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + time.AfterFunc(time.Second, func() { + // connect h2 to h1 simulating connection reversal + h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + ctx = network.WithForceDirectDial(ctx, "test") + err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()}) + assert.NoError(t, err) + }) + s, err = h1.Network().NewStream(ctx, h2.ID()) + require.NoError(t, err) + require.NotNil(t, s) +} From 7d094323d17e96166fe6acfc7f3bf627f42a8ace Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 18 Sep 2023 11:08:38 +0530 Subject: [PATCH 2/3] address review comments --- p2p/net/swarm/swarm.go | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index b40c49d36e..9ecb5dbf8d 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -398,17 +398,17 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, s.conns.Unlock() // Notify goroutines waiting for a direct connection - - // Go routines interested in waiting for direct connection first acquire this lock and then - // acquire conns.RLock. Do not acquire this lock before conns.Unlock to prevent deadlock. - s.directConnNotifs.Lock() if !c.Stat().Transient { + // Go routines interested in waiting for direct connection first acquire this lock + // and then acquire s.conns.RLock. Do not acquire this lock before conns.Unlock to + // prevent deadlock. + s.directConnNotifs.Lock() for _, ch := range s.directConnNotifs.m[p] { close(ch) } delete(s.directConnNotifs.m, p) + s.directConnNotifs.Unlock() } - s.directConnNotifs.Unlock() // Emit event after releasing `s.conns` lock so that a consumer can still // use swarm methods that need the `s.conns` lock. @@ -466,15 +466,15 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error // // TODO: Try all connections even if we get an error opening a stream on // a non-closed connection. - dialed := false - for i := 0; i < 1; i++ { + numDials := 0 + for { c := s.bestConnToPeer(p) if c == nil { if nodial, _ := network.GetNoDial(ctx); !nodial { - if dialed { + numDials++ + if numDials > DialAttempts { return nil, errors.New("max dial attempts exceeded") } - dialed = true var err error c, err = s.dialPeer(ctx, p) if err != nil { @@ -503,7 +503,6 @@ func (s *Swarm) NewStream(ctx context.Context, p peer.ID) (network.Stream, error } return str, nil } - return nil, network.ErrNoConn } // waitForDirectConn waits for a direct connection established through hole punching or connection reversal. @@ -524,14 +523,17 @@ func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) s.directConnNotifs.m[p] = append(s.directConnNotifs.m[p], ch) s.directConnNotifs.Unlock() - // Wait for notification. - // There's no point waiting for more than a minute here. - ctx, cancel := context.WithTimeout(ctx, time.Minute) + // apply the DialPeer timeout + ctx, cancel := context.WithTimeout(ctx, network.GetDialPeerTimeout(ctx)) defer cancel() + + // Wait for notification. select { case <-ctx.Done(): // Remove ourselves from the notification list s.directConnNotifs.Lock() + defer s.directConnNotifs.Unlock() + s.directConnNotifs.m[p] = slices.DeleteFunc( s.directConnNotifs.m[p], func(c chan struct{}) bool { return c == ch }, @@ -539,15 +541,15 @@ func (s *Swarm) waitForDirectConn(ctx context.Context, p peer.ID) (*Conn, error) if len(s.directConnNotifs.m[p]) == 0 { delete(s.directConnNotifs.m, p) } - s.directConnNotifs.Unlock() return nil, ctx.Err() case <-ch: // We do not need to remove ourselves from the list here as the notifier - // clears the map + // clears the map entry c := s.bestConnToPeer(p) if c == nil { return nil, network.ErrNoConn - } else if c.Stat().Transient { + } + if c.Stat().Transient { return nil, network.ErrTransientConn } return c, nil From e1d57287e4d835232c75eb31d31d440f80c71081 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 16 Oct 2023 02:45:40 +0530 Subject: [PATCH 3/3] swarm: fix Transient stream tests to not use a timer --- p2p/test/basichost/basic_host_test.go | 27 ++++++++++----- p2p/test/swarm/swarm_test.go | 49 +++++++++++++++++++++------ 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/p2p/test/basichost/basic_host_test.go b/p2p/test/basichost/basic_host_test.go index 025f906464..13800fef99 100644 --- a/p2p/test/basichost/basic_host_test.go +++ b/p2p/test/basichost/basic_host_test.go @@ -84,7 +84,7 @@ func TestNewStreamTransientConnection(t *testing.T) { require.NoError(t, err) h2, err := libp2p.New( - libp2p.NoListenAddrs, + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), libp2p.EnableRelay(), ) require.NoError(t, err) @@ -128,9 +128,20 @@ func TestNewStreamTransientConnection(t *testing.T) { // NewStream should return a stream if a direct connection is established // while waiting - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - time.AfterFunc(time.Second, func() { + done := make(chan bool, 2) + go func() { + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.TempAddrTTL) + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = network.WithNoDial(ctx, "test") + s, err = h1.NewStream(ctx, h2.ID(), "/testprotocol") + require.NoError(t, err) + require.NotNil(t, s) + defer s.Close() + require.Equal(t, s.Conn().Stat().Direction, network.DirInbound) + done <- true + }() + go func() { // connect h2 to h1 simulating connection reversal h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -138,8 +149,8 @@ func TestNewStreamTransientConnection(t *testing.T) { ctx = network.WithForceDirectDial(ctx, "test") err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()}) assert.NoError(t, err) - }) - s, err = h1.NewStream(ctx, h2.ID(), "/testprotocol") - require.NoError(t, err) - require.NotNil(t, s) + done <- true + }() + <-done + <-done } diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index e5bcdbffbe..ec2ae60469 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -79,7 +79,7 @@ func TestNewStreamTransientConnection(t *testing.T) { require.NoError(t, err) h2, err := libp2p.New( - libp2p.NoListenAddrs, + libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"), libp2p.EnableRelay(), ) require.NoError(t, err) @@ -107,18 +107,46 @@ func TestNewStreamTransientConnection(t *testing.T) { h1.Peerstore().AddAddr(h2.ID(), relayaddr, peerstore.TempAddrTTL) - // NewStream should block transient connections till we have a direct connection + // WithUseTransient should succeed ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() + ctx = network.WithUseTransient(ctx, "test") s, err := h1.Network().NewStream(ctx, h2.ID()) + require.NoError(t, err) + require.NotNil(t, s) + defer s.Close() + + // Without WithUseTransient should fail with context deadline exceeded + ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + s, err = h1.Network().NewStream(ctx, h2.ID()) require.ErrorIs(t, err, context.DeadlineExceeded) require.Nil(t, s) - // NewStream should return a stream if a direct connection is established - // while waiting - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + // Provide h2's direct address to h1. + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.TempAddrTTL) + // network.NoDial should also fail + ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - time.AfterFunc(time.Second, func() { + ctx = network.WithNoDial(ctx, "test") + s, err = h1.Network().NewStream(ctx, h2.ID()) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, s) + + done := make(chan bool, 2) + // NewStream should return a stream if an incoming direct connection is established + go func() { + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ctx = network.WithNoDial(ctx, "test") + s, err = h1.Network().NewStream(ctx, h2.ID()) + assert.NoError(t, err) + assert.NotNil(t, s) + defer s.Close() + require.Equal(t, s.Conn().Stat().Direction, network.DirInbound) + done <- true + }() + go func() { // connect h2 to h1 simulating connection reversal h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.TempAddrTTL) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -126,8 +154,9 @@ func TestNewStreamTransientConnection(t *testing.T) { ctx = network.WithForceDirectDial(ctx, "test") err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID()}) assert.NoError(t, err) - }) - s, err = h1.Network().NewStream(ctx, h2.ID()) - require.NoError(t, err) - require.NotNil(t, s) + done <- true + }() + + <-done + <-done }