diff --git a/go.mod b/go.mod index 8b84ea7..506b3b0 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/libp2p/go-libp2p v0.33.0 github.com/libp2p/go-libp2p-kad-dht v0.25.2 github.com/libp2p/go-libp2p-record v0.2.0 + github.com/multiformats/go-multiaddr v0.12.2 github.com/multiformats/go-multibase v0.2.0 github.com/multiformats/go-multihash v0.2.3 github.com/prometheus/client_golang v1.19.0 @@ -80,7 +81,6 @@ require ( github.com/mr-tron/base58 v1.2.0 // indirect github.com/multiformats/go-base32 v0.1.0 // indirect github.com/multiformats/go-base36 v0.2.0 // indirect - github.com/multiformats/go-multiaddr v0.12.2 // indirect github.com/multiformats/go-multiaddr-dns v0.3.1 // indirect github.com/multiformats/go-multiaddr-fmt v0.1.0 // indirect github.com/multiformats/go-multicodec v0.9.0 // indirect diff --git a/server.go b/server.go index 3ca58a8..c16bb20 100644 --- a/server.go +++ b/server.go @@ -168,7 +168,7 @@ func getCombinedRouting(endpoints []string, dht routing.Routing) (router, error) routers = append(routers, clientRouter{Client: drclient}) } - return parallelRouter{ + return sanitizeRouter{parallelRouter{ routers: append(routers, libp2pRouter{routing: dht}), - }, nil + }}, nil } diff --git a/server_routers.go b/server_routers.go index fa32bd7..33430a9 100644 --- a/server_routers.go +++ b/server_routers.go @@ -3,6 +3,7 @@ package main import ( "context" "errors" + "reflect" "sync" "time" @@ -14,6 +15,7 @@ import ( "github.com/ipfs/go-cid" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/routing" + manet "github.com/multiformats/go-multiaddr/net" ) type router interface { @@ -310,7 +312,7 @@ type libp2pRouter struct { func (d libp2pRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.Record], error) { ctx, cancel := context.WithCancel(ctx) ch := d.routing.FindProvidersAsync(ctx, key, limit) - return iter.ToResultIter[types.Record](&peerChanIter{ + return iter.ToResultIter(&peerChanIter{ ch: ch, cancel: cancel, }), nil @@ -334,7 +336,7 @@ func (d libp2pRouter) FindPeers(ctx context.Context, pid peer.ID, limit int) (it rec.Addrs = append(rec.Addrs, types.Multiaddr{Multiaddr: addr}) } - return iter.ToResultIter[*types.PeerRecord](iter.FromSlice[*types.PeerRecord]([]*types.PeerRecord{rec})), nil + return iter.ToResultIter(iter.FromSlice([]*types.PeerRecord{rec})), nil } func (d libp2pRouter) GetIPNS(ctx context.Context, name ipns.Name) (*ipns.Record, error) { @@ -412,3 +414,83 @@ func (d clientRouter) FindProviders(ctx context.Context, cid cid.Cid, limit int) func (d clientRouter) FindPeers(ctx context.Context, pid peer.ID, limit int) (iter.ResultIter[*types.PeerRecord], error) { return d.Client.FindPeers(ctx, pid) } + +var _ server.ContentRouter = sanitizeRouter{} + +type sanitizeRouter struct { + router +} + +func (r sanitizeRouter) FindProviders(ctx context.Context, key cid.Cid, limit int) (iter.ResultIter[types.Record], error) { + it, err := r.router.FindProviders(ctx, key, limit) + if err != nil { + return nil, err + } + + return iter.Map(it, func(v iter.Result[types.Record]) iter.Result[types.Record] { + if v.Err != nil || v.Val == nil { + return v + } + + switch v.Val.GetSchema() { + case types.SchemaPeer: + result, ok := v.Val.(*types.PeerRecord) + if !ok { + logger.Errorw("problem casting find providers result", "Schema", v.Val.GetSchema(), "Type", reflect.TypeOf(v).String()) + return v + } + + result.Addrs = filterPrivateMultiaddr(result.Addrs) + v.Val = result + + //lint:ignore SA1019 // ignore staticcheck + case types.SchemaBitswap: + //lint:ignore SA1019 // ignore staticcheck + result, ok := v.Val.(*types.BitswapRecord) + if !ok { + logger.Errorw("problem casting find providers result", "Schema", v.Val.GetSchema(), "Type", reflect.TypeOf(v).String()) + return v + } + + result.Addrs = filterPrivateMultiaddr(result.Addrs) + v.Val = result + } + + return v + }), nil +} + +func (r sanitizeRouter) FindPeers(ctx context.Context, pid peer.ID, limit int) (iter.ResultIter[*types.PeerRecord], error) { + it, err := r.router.FindPeers(ctx, pid, limit) + if err != nil { + return nil, err + } + + return iter.Map(it, func(v iter.Result[*types.PeerRecord]) iter.Result[*types.PeerRecord] { + if v.Err != nil || v.Val == nil { + return v + } + + v.Val.Addrs = filterPrivateMultiaddr(v.Val.Addrs) + return v + }), nil +} + +//lint:ignore SA1019 // ignore staticcheck +func (r sanitizeRouter) ProvideBitswap(ctx context.Context, req *server.BitswapWriteProvideRequest) (time.Duration, error) { + return 0, routing.ErrNotSupported +} + +func filterPrivateMultiaddr(a []types.Multiaddr) []types.Multiaddr { + b := make([]types.Multiaddr, 0, len(a)) + + for _, addr := range a { + if manet.IsPrivateAddr(addr.Multiaddr) { + continue + } + + b = append(b, addr) + } + + return b +} diff --git a/server_routers_test.go b/server_routers_test.go index 2ef826d..e5669f9 100644 --- a/server_routers_test.go +++ b/server_routers_test.go @@ -15,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/routing" + "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multihash" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -278,6 +279,12 @@ func makeCID() cid.Cid { return c } +func mustMultiaddr(t *testing.T, s string) types.Multiaddr { + ma, err := multiaddr.NewMultiaddr(s) + require.NoError(t, err) + return types.Multiaddr{Multiaddr: ma} +} + func TestFindProviders(t *testing.T) { t.Parallel() @@ -286,7 +293,8 @@ func TestFindProviders(t *testing.T) { c := makeCID() peers := []peer.ID{"peer1", "peer2", "peer3"} - d := parallelRouter{} + var d router + d = parallelRouter{} it, err := d.FindProviders(ctx, c, 10) require.NoError(t, err) @@ -300,17 +308,25 @@ func TestFindProviders(t *testing.T) { mr2Iter := newMockIter[types.Record](ctx) mr2.On("FindProviders", mock.Anything, c, 10).Return(mr2Iter, nil) - d = parallelRouter{ + d = sanitizeRouter{parallelRouter{ routers: []router{ &composableRouter{ providers: mr1, }, mr2, }, - } + }} + + privateAddr := mustMultiaddr(t, "/ip4/192.168.1.123/tcp/4001") + loopbackAddr := mustMultiaddr(t, "/ip4/127.0.0.1/tcp/4001") + publicAddr := mustMultiaddr(t, "/ip4/137.21.14.12/tcp/4001") go func() { - mr1Iter.ch <- iter.Result[types.Record]{Val: &types.PeerRecord{Schema: "peer", ID: &peers[0]}} + mr1Iter.ch <- iter.Result[types.Record]{Val: &types.PeerRecord{ + Schema: "peer", + ID: &peers[0], + Addrs: []types.Multiaddr{privateAddr, loopbackAddr, publicAddr}, + }} mr2Iter.ch <- iter.Result[types.Record]{Val: &types.PeerRecord{Schema: "peer", ID: &peers[0]}} mr1Iter.ch <- iter.Result[types.Record]{Val: &types.PeerRecord{Schema: "peer", ID: &peers[1]}} mr1Iter.ch <- iter.Result[types.Record]{Val: &types.PeerRecord{Schema: "peer", ID: &peers[2]}} @@ -326,6 +342,9 @@ func TestFindProviders(t *testing.T) { results, err := iter.ReadAllResults(it) require.NoError(t, err) require.Len(t, results, 5) + + require.Len(t, results[0].(*types.PeerRecord).Addrs, 1) + require.Equal(t, publicAddr.String(), results[0].(*types.PeerRecord).Addrs[0].String()) }) t.Run("Failed to Create All Iterators", func(t *testing.T) {