From 9c2be3c6f6c9d63bf2d228a740577b0a8542461e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 13 Apr 2022 12:57:45 +0100 Subject: [PATCH 01/44] initial commit --- p2p/transport/webtransport/conn.go | 75 +++++++++++ p2p/transport/webtransport/crypto.go | 87 ++++++++++++ p2p/transport/webtransport/listener.go | 133 +++++++++++++++++++ p2p/transport/webtransport/multiaddr.go | 22 +++ p2p/transport/webtransport/stream.go | 55 ++++++++ p2p/transport/webtransport/transport.go | 104 +++++++++++++++ p2p/transport/webtransport/transport_test.go | 38 ++++++ 7 files changed, 514 insertions(+) create mode 100644 p2p/transport/webtransport/conn.go create mode 100644 p2p/transport/webtransport/crypto.go create mode 100644 p2p/transport/webtransport/listener.go create mode 100644 p2p/transport/webtransport/multiaddr.go create mode 100644 p2p/transport/webtransport/stream.go create mode 100644 p2p/transport/webtransport/transport.go create mode 100644 p2p/transport/webtransport/transport_test.go diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go new file mode 100644 index 0000000000..40ba99f1a1 --- /dev/null +++ b/p2p/transport/webtransport/conn.go @@ -0,0 +1,75 @@ +package libp2pwebtransport + +import ( + "context" + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + tpt "github.com/libp2p/go-libp2p-core/transport" + "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" +) + +type conn struct { + transport tpt.Transport + wconn *webtransport.Conn +} + +var _ tpt.CapableConn = &conn{} + +func (c *conn) Close() error { + return c.wconn.Close() +} + +func (c *conn) IsClosed() bool { + panic("implement me") +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + str, err := c.wconn.OpenStreamSync(ctx) + return &stream{str}, err +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + str, err := c.wconn.AcceptStream(context.Background()) + return &stream{str}, err +} + +func (c *conn) LocalPeer() peer.ID { + // TODO implement me + panic("implement me") +} + +func (c *conn) LocalPrivateKey() crypto.PrivKey { + // TODO implement me + panic("implement me") +} + +func (c *conn) RemotePeer() peer.ID { + // TODO implement me + panic("implement me") +} + +func (c *conn) RemotePublicKey() crypto.PubKey { + // TODO implement me + panic("implement me") +} + +func (c *conn) LocalMultiaddr() ma.Multiaddr { + // TODO implement me + panic("implement me") +} + +func (c *conn) RemoteMultiaddr() ma.Multiaddr { + // TODO implement me + panic("implement me") +} + +func (c *conn) Scope() network.ConnScope { + // TODO implement me + panic("implement me") +} + +func (c *conn) Transport() tpt.Transport { + return c.transport +} diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go new file mode 100644 index 0000000000..f658eee3a5 --- /dev/null +++ b/p2p/transport/webtransport/crypto.go @@ -0,0 +1,87 @@ +package libp2pwebtransport + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "time" +) + +func certificateHash(c *tls.Config) [32]byte { + return sha256.Sum256(c.Certificates[0].Certificate[0]) +} + +// TODO: don't generate a chain here. A single cert should be sufficient (and save bytes). +func getTLSConf() (*tls.Config, error) { + ca, caPrivateKey, err := generateCA() + if err != nil { + return nil, err + } + leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey) + if err != nil { + return nil, err + } + certPool := x509.NewCertPool() + certPool.AddCert(ca) + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{leafCert.Raw}, + PrivateKey: leafPrivateKey, + }}, + }, nil +} + +func generateCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + return ca, caPrivateKey, nil +} + +func generateLeafCert(ca *x509.Certificate, caPrivateKey *ecdsa.PrivateKey) (*x509.Certificate, *ecdsa.PrivateKey, error) { + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"localhost"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, err + } + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, err + } + return cert, privKey, nil +} diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go new file mode 100644 index 0000000000..e422e23763 --- /dev/null +++ b/p2p/transport/webtransport/listener.go @@ -0,0 +1,133 @@ +package libp2pwebtransport + +import ( + "crypto/tls" + "errors" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + "net" + "net/http" + + tpt "github.com/libp2p/go-libp2p-core/transport" + + "github.com/lucas-clemente/quic-go/http3" + "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +var errClosed = errors.New("closed") + +type listener struct { + server webtransport.Server + tlsConf *tls.Config + + closed chan struct{} // is closed when Close is called + serverClosed chan struct{} // is closed when server.Serve returns + + addr net.Addr + multiaddr ma.Multiaddr + + queue chan *webtransport.Conn +} + +var _ tpt.Listener = &listener{} + +func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) { + network, addr, err := manet.DialArgs(laddr) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP(network, udpAddr) + if err != nil { + return nil, err + } + localMultiaddr, err := toWebtransportMultiaddr(udpConn.LocalAddr()) + if err != nil { + return nil, err + } + ln := &listener{ + queue: make(chan *webtransport.Conn, 10), + serverClosed: make(chan struct{}), + addr: udpConn.LocalAddr(), + tlsConf: tlsConf, + multiaddr: localMultiaddr, + } + server := webtransport.Server{ + H3: http3.Server{ + Server: &http.Server{ + TLSConfig: tlsConf, + }, + }, + } + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, world!")) + }) + mux.HandleFunc(webtransportHTTPEndpoint, func(w http.ResponseWriter, r *http.Request) { + // TODO: check ?type=multistream URL param + c, err := server.Upgrade(w, r) + if err != nil { + w.WriteHeader(500) + return + } + ln.queue <- c + }) + server.H3.Handler = mux + go func() { + defer close(ln.serverClosed) + defer func() { udpConn.Close() }() + if err := server.Serve(udpConn); err != nil { + // TODO: only output if the server hasn't been closed + log.Debugw("serving failed", "addr", udpConn.LocalAddr(), "error", err) + } + }() + ln.server = server + return ln, nil +} + +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case <-l.closed: + return nil, errClosed + default: + } + + var c *webtransport.Conn + select { + case c = <-l.queue: + // TODO: libp2p handshake + // TODO: pass in transport + return &conn{wconn: c}, nil + case <-l.closed: + return nil, errClosed + } +} + +func (l *listener) Addr() net.Addr { + return l.addr +} + +func (l *listener) Multiaddr() ma.Multiaddr { + certHash := certificateHash(l.tlsConf) + h, err := multihash.Encode(certHash[:], multihash.SHA2_256) + if err != nil { + panic(err) + } + certHashStr, err := multibase.Encode(multibase.Base58BTC, h) + if err != nil { + panic(err) + } + return l.multiaddr.Encapsulate(ma.StringCast("/certhash/" + certHashStr)) +} + +func (l *listener) Close() error { + close(l.closed) + err := l.server.Close() + <-l.serverClosed + return err +} diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go new file mode 100644 index 0000000000..a84853c1ee --- /dev/null +++ b/p2p/transport/webtransport/multiaddr.go @@ -0,0 +1,22 @@ +package libp2pwebtransport + +import ( + "net" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +var webtransportMA = ma.StringCast("/quic/webtransport") + +func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { + udpMA, err := manet.FromNetAddr(na) + if err != nil { + return nil, err + } + return udpMA.Encapsulate(webtransportMA), nil +} + +func fromWebtransportMultiaddr(addr ma.Multiaddr) (net.Addr, error) { + return manet.ToNetAddr(addr.Decapsulate(webtransportMA)) +} diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go new file mode 100644 index 0000000000..cebfcfb2fb --- /dev/null +++ b/p2p/transport/webtransport/stream.go @@ -0,0 +1,55 @@ +package libp2pwebtransport + +import ( + "errors" + + "github.com/marten-seemann/webtransport-go" + + "github.com/libp2p/go-libp2p-core/network" +) + +const ( + reset webtransport.ErrorCode = 0 +) + +type stream struct { + webtransport.Stream +} + +var _ network.MuxedStream = &stream{} + +func (s *stream) Read(b []byte) (n int, err error) { + n, err = s.Stream.Read(b) + if err != nil && errors.Is(err, &webtransport.StreamError{}) { + err = network.ErrReset + } + return n, err +} + +func (s *stream) Write(b []byte) (n int, err error) { + n, err = s.Stream.Write(b) + if err != nil && errors.Is(err, &webtransport.StreamError{}) { + err = network.ErrReset + } + return n, err +} + +func (s *stream) Reset() error { + s.Stream.CancelRead(reset) + s.Stream.CancelWrite(reset) + return nil +} + +func (s *stream) Close() error { + s.Stream.CancelRead(reset) + return s.Stream.Close() +} + +func (s *stream) CloseRead() error { + s.Stream.CancelRead(reset) + return nil +} + +func (s *stream) CloseWrite() error { + return s.Stream.Close() +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go new file mode 100644 index 0000000000..1c91413c1b --- /dev/null +++ b/p2p/transport/webtransport/transport.go @@ -0,0 +1,104 @@ +package libp2pwebtransport + +import ( + "context" + "crypto/tls" + "fmt" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + "sync" + + "github.com/libp2p/go-libp2p-core/peer" + tpt "github.com/libp2p/go-libp2p-core/transport" + + logging "github.com/ipfs/go-log/v2" + "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" + manet "github.com/multiformats/go-multiaddr/net" +) + +var log = logging.Logger("webtransport") + +const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" + +type transport struct { + tlsConf *tls.Config + dialer webtransport.Dialer + + initOnce sync.Once + server webtransport.Server +} + +var _ tpt.Transport = &transport{} + +func New() (tpt.Transport, error) { + tlsConf, err := getTLSConf() // TODO: only do this when initializing a listener + if err != nil { + return nil, err + } + return &transport{ + tlsConf: tlsConf, + dialer: webtransport.Dialer{ + TLSClientConf: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, + }, + }, nil +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + _, addr, err := manet.DialArgs(raddr) + if err != nil { + return nil, err + } + url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) + certHashesStr := make([]string, 0, 2) + ma.ForEach(raddr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + certHashesStr = append(certHashesStr, c.Value()) + } + return true + }) + var certHashes []multihash.DecodedMultihash + for _, s := range certHashesStr { + _, ch, err := multibase.Decode(s) + if err != nil { + return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err) + } + dh, err := multihash.Decode(ch) + if err != nil { + return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err) + } + certHashes = append(certHashes, *dh) + } + rsp, wconn, err := t.dialer.Dial(ctx, url, nil) + if err != nil { + return nil, err + } + defer rsp.Body.Close() + if rsp.StatusCode < 200 || rsp.StatusCode > 299 { + return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) + } + // TODO: run handshake on conn + return &conn{ + transport: t, + wconn: wconn, + }, nil +} + +var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return dialMatcher.Matches(addr) +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + return newListener(laddr, t.tlsConf) +} + +func (t *transport) Protocols() []int { + return []int{ma.P_WEBTRANSPORT} +} + +func (t *transport) Proxy() bool { + return false +} diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go new file mode 100644 index 0000000000..bf65017dcf --- /dev/null +++ b/p2p/transport/webtransport/transport_test.go @@ -0,0 +1,38 @@ +package libp2pwebtransport_test + +import ( + "context" + "io" + "testing" + + libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func TestTransport(t *testing.T) { + tr, err := libp2pwebtransport.New() + require.NoError(t, err) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + + go func() { + tr2, err := libp2pwebtransport.New() + require.NoError(t, err) + conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), "peer") + require.NoError(t, err) + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + _, err = str.Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, str.Close()) + }() + + conn, err := ln.Accept() + require.NoError(t, err) + str, err := conn.AcceptStream() + require.NoError(t, err) + data, err := io.ReadAll(str) + require.NoError(t, err) + require.Equal(t, "foobar", string(data)) +} From 528fa96706a43db1a0388642bf40bb9957fa8e14 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Apr 2022 14:52:10 +0100 Subject: [PATCH 02/44] use a context to control listener shutdown --- p2p/transport/webtransport/listener.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index e422e23763..9d95acd9e2 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -1,10 +1,10 @@ package libp2pwebtransport import ( + "context" "crypto/tls" "errors" - "github.com/multiformats/go-multibase" - "github.com/multiformats/go-multihash" + "net" "net/http" @@ -14,6 +14,8 @@ import ( "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" ) var errClosed = errors.New("closed") @@ -22,7 +24,9 @@ type listener struct { server webtransport.Server tlsConf *tls.Config - closed chan struct{} // is closed when Close is called + ctx context.Context + ctxCancel context.CancelFunc + serverClosed chan struct{} // is closed when server.Serve returns addr net.Addr @@ -57,6 +61,7 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) tlsConf: tlsConf, multiaddr: localMultiaddr, } + ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) server := webtransport.Server{ H3: http3.Server{ Server: &http.Server{ @@ -92,7 +97,7 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) func (l *listener) Accept() (tpt.CapableConn, error) { select { - case <-l.closed: + case <-l.ctx.Done(): return nil, errClosed default: } @@ -103,7 +108,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { // TODO: libp2p handshake // TODO: pass in transport return &conn{wconn: c}, nil - case <-l.closed: + case <-l.ctx.Done(): return nil, errClosed } } @@ -126,7 +131,7 @@ func (l *listener) Multiaddr() ma.Multiaddr { } func (l *listener) Close() error { - close(l.closed) + l.ctxCancel() err := l.server.Close() <-l.serverClosed return err From c7ebc6d8343097c5443c5f74da1330e274465416 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Apr 2022 15:15:06 +0100 Subject: [PATCH 03/44] run a Noise handshake on the first stream To make this actually secure, we still need to verify the certificate hashes. --- p2p/transport/webtransport/conn.go | 49 ++++++++------ p2p/transport/webtransport/listener.go | 67 +++++++++++++++----- p2p/transport/webtransport/stream.go | 16 +++++ p2p/transport/webtransport/transport.go | 39 ++++++++++-- p2p/transport/webtransport/transport_test.go | 16 ++++- 5 files changed, 143 insertions(+), 44 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 40ba99f1a1..e45daedd86 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -2,10 +2,13 @@ package libp2pwebtransport import ( "context" + "github.com/libp2p/go-libp2p-core/crypto" + ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" + "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" ) @@ -13,6 +16,29 @@ import ( type conn struct { transport tpt.Transport wconn *webtransport.Conn + + localPeer, remotePeer peer.ID + privKey ic.PrivKey + remotePubKey ic.PubKey +} + +func newConn(tr tpt.Transport, wconn *webtransport.Conn, privKey ic.PrivKey, remotePubKey ic.PubKey) (*conn, error) { + localPeer, err := peer.IDFromPrivateKey(privKey) + if err != nil { + return nil, err + } + remotePeer, err := peer.IDFromPublicKey(remotePubKey) + if err != nil { + return nil, err + } + return &conn{ + transport: tr, + wconn: wconn, + privKey: privKey, + localPeer: localPeer, + remotePeer: remotePeer, + remotePubKey: remotePubKey, + }, nil } var _ tpt.CapableConn = &conn{} @@ -35,25 +61,10 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { return &stream{str}, err } -func (c *conn) LocalPeer() peer.ID { - // TODO implement me - panic("implement me") -} - -func (c *conn) LocalPrivateKey() crypto.PrivKey { - // TODO implement me - panic("implement me") -} - -func (c *conn) RemotePeer() peer.ID { - // TODO implement me - panic("implement me") -} - -func (c *conn) RemotePublicKey() crypto.PubKey { - // TODO implement me - panic("implement me") -} +func (c *conn) LocalPeer() peer.ID { return c.localPeer } +func (c *conn) LocalPrivateKey() crypto.PrivKey { return c.privKey } +func (c *conn) RemotePeer() peer.ID { return c.remotePeer } +func (c *conn) RemotePublicKey() crypto.PubKey { return c.remotePubKey } func (c *conn) LocalMultiaddr() ma.Multiaddr { // TODO implement me diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 9d95acd9e2..0dadec2fa8 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -4,12 +4,14 @@ import ( "context" "crypto/tls" "errors" - "net" "net/http" + "time" tpt "github.com/libp2p/go-libp2p-core/transport" + noise "github.com/libp2p/go-libp2p-noise" + "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" @@ -20,7 +22,13 @@ import ( var errClosed = errors.New("closed") +const queueLen = 16 +const handshakeTimeout = 10 * time.Second + type listener struct { + transport tpt.Transport + noise *noise.Transport + server webtransport.Server tlsConf *tls.Config @@ -37,7 +45,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) { +func newListener(laddr ma.Multiaddr, tlsConf *tls.Config, transport tpt.Transport, noise *noise.Transport) (tpt.Listener, error) { network, addr, err := manet.DialArgs(laddr) if err != nil { return nil, err @@ -55,7 +63,9 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) return nil, err } ln := &listener{ - queue: make(chan *webtransport.Conn, 10), + transport: transport, + noise: noise, + queue: make(chan *webtransport.Conn, queueLen), serverClosed: make(chan struct{}), addr: udpConn.LocalAddr(), tlsConf: tlsConf, @@ -80,6 +90,7 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) w.WriteHeader(500) return } + // TODO: handle queue overflow ln.queue <- c }) server.H3.Handler = mux @@ -96,21 +107,47 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config) (tpt.Listener, error) } func (l *listener) Accept() (tpt.CapableConn, error) { - select { - case <-l.ctx.Done(): - return nil, errClosed - default: + queue := make(chan tpt.CapableConn, queueLen) + for { + select { + case <-l.ctx.Done(): + return nil, errClosed + default: + } + + var c *webtransport.Conn + select { + case c = <-l.queue: + go func(c *webtransport.Conn) { + ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) + defer cancel() + conn, err := l.handshake(ctx, c) + if err != nil { + log.Debugw("handshake failed", "error", err) + c.Close() + return + } + // TODO: handle queue overflow + queue <- conn + }(c) + case conn := <-queue: + return conn, nil + case <-l.ctx.Done(): + return nil, errClosed + } } +} - var c *webtransport.Conn - select { - case c = <-l.queue: - // TODO: libp2p handshake - // TODO: pass in transport - return &conn{wconn: c}, nil - case <-l.ctx.Done(): - return nil, errClosed +func (l *listener) handshake(ctx context.Context, c *webtransport.Conn) (tpt.CapableConn, error) { + str, err := c.AcceptStream(ctx) + if err != nil { + return nil, err + } + conn, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wconn: c}, "") + if err != nil { + return nil, err } + return newConn(l.transport, c, conn.LocalPrivateKey(), conn.RemotePublicKey()) } func (l *listener) Addr() net.Addr { diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index cebfcfb2fb..fa34242b6c 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -2,6 +2,7 @@ package libp2pwebtransport import ( "errors" + "net" "github.com/marten-seemann/webtransport-go" @@ -12,6 +13,21 @@ const ( reset webtransport.ErrorCode = 0 ) +type webtransportStream struct { + webtransport.Stream + wconn *webtransport.Conn +} + +var _ net.Conn = &webtransportStream{} + +func (s *webtransportStream) LocalAddr() net.Addr { + return s.wconn.LocalAddr() +} + +func (s *webtransportStream) RemoteAddr() net.Addr { + return s.wconn.RemoteAddr() +} + type stream struct { webtransport.Stream } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 1c91413c1b..7a0098fe77 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -8,8 +8,10 @@ import ( "github.com/multiformats/go-multihash" "sync" + ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" + noise "github.com/libp2p/go-libp2p-noise" logging "github.com/ipfs/go-log/v2" "github.com/marten-seemann/webtransport-go" @@ -22,26 +24,44 @@ var log = logging.Logger("webtransport") const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" +const maxProtoSize = 8 << 10 + type transport struct { + privKey ic.PrivKey + pid peer.ID + tlsConf *tls.Config dialer webtransport.Dialer initOnce sync.Once server webtransport.Server + + noise *noise.Transport } var _ tpt.Transport = &transport{} -func New() (tpt.Transport, error) { +func New(key ic.PrivKey) (tpt.Transport, error) { + id, err := peer.IDFromPrivateKey(key) + if err != nil { + return nil, err + } tlsConf, err := getTLSConf() // TODO: only do this when initializing a listener if err != nil { return nil, err } + noise, err := noise.New(key) + if err != nil { + return nil, err + } return &transport{ + pid: id, + privKey: key, tlsConf: tlsConf, dialer: webtransport.Dialer{ TLSClientConf: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, }, + noise: noise, }, nil } @@ -78,11 +98,16 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if rsp.StatusCode < 200 || rsp.StatusCode > 299 { return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) } - // TODO: run handshake on conn - return &conn{ - transport: t, - wconn: wconn, - }, nil + str, err := wconn.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + // TODO: use early data and verify the cert hash + sconn, err := t.noise.SecureOutbound(ctx, &webtransportStream{Stream: str, wconn: wconn}, p) + if err != nil { + return nil, err + } + return newConn(t, wconn, t.privKey, sconn.RemotePublicKey()) } var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) @@ -92,7 +117,7 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { - return newListener(laddr, t.tlsConf) + return newListener(laddr, t.tlsConf, t, t.noise) } func (t *transport) Protocols() []int { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index bf65017dcf..338d26ab55 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -2,24 +2,34 @@ package libp2pwebtransport_test import ( "context" + "crypto/rand" "io" "testing" + ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) func TestTransport(t *testing.T) { - tr, err := libp2pwebtransport.New() + serverKey, _, err := ic.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + serverID, err := peer.IDFromPrivateKey(serverKey) + require.NoError(t, err) + tr, err := libp2pwebtransport.New(serverKey) require.NoError(t, err) ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) go func() { - tr2, err := libp2pwebtransport.New() + clientKey, _, err := ic.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + tr2, err := libp2pwebtransport.New(clientKey) require.NoError(t, err) - conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), "peer") + conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) str, err := conn.OpenStream(context.Background()) require.NoError(t, err) From 1ea33604b99cd4be9848d5b1847e29f826678bf0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Apr 2022 18:30:03 +0100 Subject: [PATCH 04/44] generate a single certificate, not a chain --- p2p/transport/webtransport/crypto.go | 49 +++++++--------------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go index f658eee3a5..5c6dcdb102 100644 --- a/p2p/transport/webtransport/crypto.go +++ b/p2p/transport/webtransport/crypto.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/binary" "math/big" "time" ) @@ -16,31 +17,29 @@ func certificateHash(c *tls.Config) [32]byte { return sha256.Sum256(c.Certificates[0].Certificate[0]) } -// TODO: don't generate a chain here. A single cert should be sufficient (and save bytes). func getTLSConf() (*tls.Config, error) { - ca, caPrivateKey, err := generateCA() + cert, priv, err := generateCert() if err != nil { return nil, err } - leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey) - if err != nil { - return nil, err - } - certPool := x509.NewCertPool() - certPool.AddCert(ca) return &tls.Config{ Certificates: []tls.Certificate{{ - Certificate: [][]byte{leafCert.Raw}, - PrivateKey: leafPrivateKey, + Certificate: [][]byte{cert.Raw}, + PrivateKey: priv, }}, }, nil } -func generateCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { +func generateCert() (*x509.Certificate, *ecdsa.PrivateKey, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return nil, nil, err + } + serial := binary.BigEndian.Uint64(b) certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(2019), + SerialNumber: big.NewInt(int64(serial)), Subject: pkix.Name{}, - NotBefore: time.Now().Add(-time.Hour), + NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, @@ -61,27 +60,3 @@ func generateCA() (*x509.Certificate, *ecdsa.PrivateKey, error) { } return ca, caPrivateKey, nil } - -func generateLeafCert(ca *x509.Certificate, caPrivateKey *ecdsa.PrivateKey) (*x509.Certificate, *ecdsa.PrivateKey, error) { - certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(1), - DNSNames: []string{"localhost"}, - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(24 * time.Hour), - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, nil, err - } - certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey) - if err != nil { - return nil, nil, err - } - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - return nil, nil, err - } - return cert, privKey, nil -} From a077c8901ffc82989106bfd577ce99dfaef318aa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 17 Apr 2022 11:45:54 +0100 Subject: [PATCH 05/44] implement a basic certificate manager --- p2p/transport/webtransport/cert_manager.go | 164 ++++++++++++++++++ .../webtransport/cert_manager_test.go | 72 ++++++++ p2p/transport/webtransport/crypto.go | 11 +- p2p/transport/webtransport/transport.go | 6 +- 4 files changed, 247 insertions(+), 6 deletions(-) create mode 100644 p2p/transport/webtransport/cert_manager.go create mode 100644 p2p/transport/webtransport/cert_manager_test.go diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go new file mode 100644 index 0000000000..83fc2878cf --- /dev/null +++ b/p2p/transport/webtransport/cert_manager.go @@ -0,0 +1,164 @@ +package libp2pwebtransport + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "sync" + "time" + + ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" +) + +type certConfig struct { + start, end time.Time + tlsConf *tls.Config +} + +func newCertConfig(start, end time.Time, conf *tls.Config) (*certConfig, error) { + return &certConfig{ + start: start, + end: end, + tlsConf: conf, + }, nil +} + +// Certificate renewal logic: +// 0. To simplify the math, assume the certificate is valid for 10 days (in real life: 14 days). +// 1. On startup, we generate the first certificate (1). +// 2. After 4 days, we generate a second certificate (2). +// We don't use that certificate yet, but we advertise the hashes of (1) and (2). +// That allows clients to connect to us using addresses that are 4 days old. +// 3. After another 4 days, we now actually start using (2). +// We also generate a third certificate (3), and start advertising the hashes of (2) and (3). +// We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days, +// as the client might be connecting with a cached address. +type certManager struct { + ctx context.Context + ctxCancel context.CancelFunc + refCount sync.WaitGroup + + certValidity time.Duration // so we can set it in tests + + mx sync.Mutex + lastConfig *certConfig // initially nil + currentConfig *certConfig + nextConfig *certConfig // nil until we have passed half the certValidity of the current config + addrComp ma.Multiaddr +} + +func newCertManager(certValidity time.Duration) (*certManager, error) { + m := &certManager{ + certValidity: certValidity, + } + m.ctx, m.ctxCancel = context.WithCancel(context.Background()) + if err := m.init(); err != nil { + return nil, err + } + m.refCount.Add(1) + go func() { + defer m.refCount.Done() + if err := m.background(); err != nil { + log.Fatal(err) + } + }() + return m, nil +} + +func (m *certManager) init() error { + start := time.Now() + end := start.Add(m.certValidity) + tlsConf, err := getTLSConf(start, end) + if err != nil { + return err + } + cc, err := newCertConfig(start, end, tlsConf) + if err != nil { + return err + } + m.currentConfig = cc + return m.cacheAddrComponent() +} + +func (m *certManager) background() error { + t := time.NewTicker(m.certValidity * 4 / 9) // make sure we're a bit faster than 1/2 + defer t.Stop() + + for { + select { + case <-m.ctx.Done(): + return nil + case start := <-t.C: + end := start.Add(m.certValidity) + tlsConf, err := getTLSConf(start, end) + if err != nil { + return err + } + cc, err := newCertConfig(start, end, tlsConf) + if err != nil { + return err + } + m.mx.Lock() + if m.nextConfig != nil { + m.lastConfig = m.currentConfig + m.currentConfig = m.nextConfig + } + m.nextConfig = cc + if err := m.cacheAddrComponent(); err != nil { + m.mx.Unlock() + return err + } + m.mx.Unlock() + } + } +} + +func (m *certManager) GetConfig() *tls.Config { + m.mx.Lock() + defer m.mx.Unlock() + return m.currentConfig.tlsConf +} + +func (m *certManager) AddrComponent() ma.Multiaddr { + m.mx.Lock() + defer m.mx.Unlock() + return m.addrComp +} + +func (m *certManager) cacheAddrComponent() error { + addr, err := m.addrComponentForCert(m.currentConfig.tlsConf.Certificates[0].Leaf) + if err != nil { + return err + } + if m.nextConfig != nil { + comp, err := m.addrComponentForCert(m.nextConfig.tlsConf.Certificates[0].Leaf) + if err != nil { + return err + } + addr = addr.Encapsulate(comp) + } + m.addrComp = addr + return nil +} + +func (m *certManager) addrComponentForCert(cert *x509.Certificate) (ma.Multiaddr, error) { + hash := sha256.Sum256(cert.Raw) + mh, err := multihash.Encode(hash[:], multihash.SHA2_256) + if err != nil { + return nil, err + } + certStr, err := multibase.Encode(multibase.Base58BTC, mh) + if err != nil { + return nil, err + } + return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) +} + +func (m *certManager) Close() error { + m.ctxCancel() + m.refCount.Wait() + return nil +} diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go new file mode 100644 index 0000000000..6b0a9d5795 --- /dev/null +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -0,0 +1,72 @@ +package libp2pwebtransport + +import ( + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + "testing" + "time" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func splitMultiaddr(addr ma.Multiaddr) []ma.Component { + var components []ma.Component + ma.ForEach(addr, func(c ma.Component) bool { + components = append(components, c) + return true + }) + return components +} + +func certHashFromComponent(t *testing.T, comp ma.Component) []byte { + t.Helper() + _, data, err := multibase.Decode(comp.Value()) + require.NoError(t, err) + mh, err := multihash.Decode(data) + require.NoError(t, err) + require.Equal(t, uint64(multihash.SHA2_256), mh.Code) + return mh.Digest +} + +func TestInitialCert(t *testing.T) { + m, err := newCertManager(certValidity) + require.NoError(t, err) + defer m.Close() + + conf := m.GetConfig() + require.Len(t, conf.Certificates, 1) + cert := conf.Certificates[0] + require.WithinDuration(t, time.Now(), cert.Leaf.NotBefore, time.Second) + require.WithinDuration(t, time.Now().Add(certValidity), cert.Leaf.NotAfter, time.Second) + addr := m.AddrComponent() + components := splitMultiaddr(addr) + require.Len(t, components, 1) + require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) + hash := certificateHash(conf) + require.Equal(t, hash[:], certHashFromComponent(t, components[0])) +} + +func TestCertRenewal(t *testing.T) { + const certValidity = 300 * time.Millisecond + m, err := newCertManager(certValidity) + require.NoError(t, err) + defer m.Close() + + firstConf := m.GetConfig() + require.Len(t, splitMultiaddr(m.AddrComponent()), 1) + // wait for a new certificate to be generated + require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, certValidity/2, 10*time.Millisecond) + // the actual config used should still be the same, we're just advertising the hash of the next config + components := splitMultiaddr(m.AddrComponent()) + require.Len(t, components, 2) + for _, c := range components { + require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) + } + require.Equal(t, firstConf, m.GetConfig()) + require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, certValidity/2, 10*time.Millisecond) + newConf := m.GetConfig() + // check that the new config now matches the second component + hash := certificateHash(newConf) + require.Equal(t, hash[:], certHashFromComponent(t, components[1])) +} diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go index 5c6dcdb102..f30943c6f6 100644 --- a/p2p/transport/webtransport/crypto.go +++ b/p2p/transport/webtransport/crypto.go @@ -17,8 +17,8 @@ func certificateHash(c *tls.Config) [32]byte { return sha256.Sum256(c.Certificates[0].Certificate[0]) } -func getTLSConf() (*tls.Config, error) { - cert, priv, err := generateCert() +func getTLSConf(start, end time.Time) (*tls.Config, error) { + cert, priv, err := generateCert(start, end) if err != nil { return nil, err } @@ -26,11 +26,12 @@ func getTLSConf() (*tls.Config, error) { Certificates: []tls.Certificate{{ Certificate: [][]byte{cert.Raw}, PrivateKey: priv, + Leaf: cert, }}, }, nil } -func generateCert() (*x509.Certificate, *ecdsa.PrivateKey, error) { +func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) { b := make([]byte, 8) if _, err := rand.Read(b); err != nil { return nil, nil, err @@ -39,8 +40,8 @@ func generateCert() (*x509.Certificate, *ecdsa.PrivateKey, error) { certTempl := &x509.Certificate{ SerialNumber: big.NewInt(int64(serial)), Subject: pkix.Name{}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), + NotBefore: start, + NotAfter: end, IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 7a0098fe77..4a2700b85f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -7,6 +7,7 @@ import ( "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" "sync" + "time" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" @@ -26,6 +27,8 @@ const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" const maxProtoSize = 8 << 10 +const certValidity = 14 * 24 * time.Hour + type transport struct { privKey ic.PrivKey pid peer.ID @@ -46,7 +49,8 @@ func New(key ic.PrivKey) (tpt.Transport, error) { if err != nil { return nil, err } - tlsConf, err := getTLSConf() // TODO: only do this when initializing a listener + now := time.Now() + tlsConf, err := getTLSConf(now, now.Add(certValidity)) // TODO: only do this when initializing a listener if err != nil { return nil, err } From 268725c30cf6cd437176b8fe9253d21ea07c78b7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 18 Apr 2022 23:08:26 +0100 Subject: [PATCH 06/44] check that addr contains at least one certhash in CanDial --- p2p/transport/webtransport/transport.go | 17 ++++++- p2p/transport/webtransport/transport_test.go | 49 ++++++++++++++++++-- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 4a2700b85f..dba619ebf1 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,8 +4,6 @@ import ( "context" "crypto/tls" "fmt" - "github.com/multiformats/go-multibase" - "github.com/multiformats/go-multihash" "sync" "time" @@ -19,6 +17,8 @@ import ( ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" ) var log = logging.Logger("webtransport") @@ -117,6 +117,19 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) func (t *transport) CanDial(addr ma.Multiaddr) bool { + var numHashes int + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + numHashes++ + } + return true + }) + if numHashes == 0 { + return false + } + for i := 0; i < numHashes; i++ { + addr, _ = ma.SplitLast(addr) + } return dialMatcher.Matches(addr) } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 338d26ab55..a03771f9e8 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -3,6 +3,7 @@ package libp2pwebtransport_test import ( "context" "crypto/rand" + "fmt" "io" "testing" @@ -11,22 +12,27 @@ import ( libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" "github.com/stretchr/testify/require" ) -func TestTransport(t *testing.T) { - serverKey, _, err := ic.GenerateEd25519Key(rand.Reader) +func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) { + key, _, err := ic.GenerateEd25519Key(rand.Reader) require.NoError(t, err) - serverID, err := peer.IDFromPrivateKey(serverKey) + id, err := peer.IDFromPrivateKey(key) require.NoError(t, err) + return id, key +} + +func TestTransport(t *testing.T) { + serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey) require.NoError(t, err) ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) go func() { - clientKey, _, err := ic.GenerateEd25519Key(rand.Reader) - require.NoError(t, err) + _, clientKey := newIdentity(t) tr2, err := libp2pwebtransport.New(clientKey) require.NoError(t, err) conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -46,3 +52,36 @@ func TestTransport(t *testing.T) { require.NoError(t, err) require.Equal(t, "foobar", string(data)) } + +func TestCanDial(t *testing.T) { + randomHash := func(t *testing.T) string { + b := make([]byte, 16) + rand.Read(b) + s, err := multibase.Encode(multibase.Base32hex, b) + require.NoError(t, err) + return s + } + + valid := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomHash(t)), + ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic/webtransport/certhash/" + randomHash(t)), + ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s/certhash/%s", randomHash(t), randomHash(t), randomHash(t))), + } + + invalid := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport + ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic + ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke? + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport"), // missing certificate hash + } + + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key) + require.NoError(t, err) + for _, addr := range valid { + require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr) + } + for _, addr := range invalid { + require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr) + } +} From 2d295b586d5f71fc4fb07f867ad6fb04338914b0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 18 Apr 2022 23:16:41 +0100 Subject: [PATCH 07/44] fix copying of http.Server mutex in listener constructor --- p2p/transport/webtransport/listener.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 0dadec2fa8..6ead26e01c 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "github.com/lucas-clemente/quic-go/http3" "net" "net/http" "time" @@ -12,7 +13,6 @@ import ( noise "github.com/libp2p/go-libp2p-noise" - "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -70,22 +70,22 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config, transport tpt.Transpor addr: udpConn.LocalAddr(), tlsConf: tlsConf, multiaddr: localMultiaddr, - } - ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) - server := webtransport.Server{ - H3: http3.Server{ - Server: &http.Server{ - TLSConfig: tlsConf, + server: webtransport.Server{ + H3: http3.Server{ + Server: &http.Server{ + TLSConfig: tlsConf, + }, }, }, } + ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, world!")) }) mux.HandleFunc(webtransportHTTPEndpoint, func(w http.ResponseWriter, r *http.Request) { // TODO: check ?type=multistream URL param - c, err := server.Upgrade(w, r) + c, err := ln.server.Upgrade(w, r) if err != nil { w.WriteHeader(500) return @@ -93,16 +93,15 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config, transport tpt.Transpor // TODO: handle queue overflow ln.queue <- c }) - server.H3.Handler = mux + ln.server.H3.Handler = mux go func() { defer close(ln.serverClosed) defer func() { udpConn.Close() }() - if err := server.Serve(udpConn); err != nil { + if err := ln.server.Serve(udpConn); err != nil { // TODO: only output if the server hasn't been closed log.Debugw("serving failed", "addr", udpConn.LocalAddr(), "error", err) } }() - ln.server = server return ln, nil } From 59fc6efc0e38f37479911ba80b871ff44880b78f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 18 Apr 2022 23:19:22 +0100 Subject: [PATCH 08/44] fix staticcheck --- p2p/transport/webtransport/conn.go | 9 ++++----- p2p/transport/webtransport/multiaddr.go | 4 ---- p2p/transport/webtransport/transport.go | 7 +------ 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index e45daedd86..c1b5742a47 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -3,7 +3,6 @@ package libp2pwebtransport import ( "context" - "github.com/libp2p/go-libp2p-core/crypto" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -61,10 +60,10 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { return &stream{str}, err } -func (c *conn) LocalPeer() peer.ID { return c.localPeer } -func (c *conn) LocalPrivateKey() crypto.PrivKey { return c.privKey } -func (c *conn) RemotePeer() peer.ID { return c.remotePeer } -func (c *conn) RemotePublicKey() crypto.PubKey { return c.remotePubKey } +func (c *conn) LocalPeer() peer.ID { return c.localPeer } +func (c *conn) LocalPrivateKey() ic.PrivKey { return c.privKey } +func (c *conn) RemotePeer() peer.ID { return c.remotePeer } +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } func (c *conn) LocalMultiaddr() ma.Multiaddr { // TODO implement me diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index a84853c1ee..3c281e0859 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -16,7 +16,3 @@ func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { } return udpMA.Encapsulate(webtransportMA), nil } - -func fromWebtransportMultiaddr(addr ma.Multiaddr) (net.Addr, error) { - return manet.ToNetAddr(addr.Decapsulate(webtransportMA)) -} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index dba619ebf1..0c32f8fee3 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "fmt" - "sync" "time" ic "github.com/libp2p/go-libp2p-core/crypto" @@ -25,8 +24,6 @@ var log = logging.Logger("webtransport") const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" -const maxProtoSize = 8 << 10 - const certValidity = 14 * 24 * time.Hour type transport struct { @@ -36,9 +33,6 @@ type transport struct { tlsConf *tls.Config dialer webtransport.Dialer - initOnce sync.Once - server webtransport.Server - noise *noise.Transport } @@ -107,6 +101,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } // TODO: use early data and verify the cert hash + _ = certHashes sconn, err := t.noise.SecureOutbound(ctx, &webtransportStream{Stream: str, wconn: wconn}, p) if err != nil { return nil, err From 85384b744468334f256635ecbb99e9b5c33ddec8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 18 Apr 2022 23:48:17 +0100 Subject: [PATCH 09/44] fix flaky TestCertRenewal test on CI --- p2p/transport/webtransport/cert_manager_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index 6b0a9d5795..3d9c881d4e 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -3,6 +3,7 @@ package libp2pwebtransport import ( "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" + "os" "testing" "time" @@ -48,7 +49,10 @@ func TestInitialCert(t *testing.T) { } func TestCertRenewal(t *testing.T) { - const certValidity = 300 * time.Millisecond + var certValidity = 300 * time.Millisecond + if os.Getenv("CI") != "" { + certValidity = 2 * time.Second + } m, err := newCertManager(certValidity) require.NoError(t, err) defer m.Close() From 77cdde610e75bd3a64b608fbc76ec6d81ca824b2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 19 Apr 2022 00:00:27 +0100 Subject: [PATCH 10/44] check validity of listen addresses --- p2p/transport/webtransport/multiaddr.go | 3 ++ p2p/transport/webtransport/transport.go | 8 ++-- p2p/transport/webtransport/transport_test.go | 49 +++++++++++++++----- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index 3c281e0859..5560d90f68 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -4,11 +4,14 @@ import ( "net" ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" ) var webtransportMA = ma.StringCast("/quic/webtransport") +var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) + func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { udpMA, err := manet.FromNetAddr(na) if err != nil { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 0c32f8fee3..4e1efb6cac 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -14,7 +14,6 @@ import ( logging "github.com/ipfs/go-log/v2" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" - mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" @@ -109,8 +108,6 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return newConn(t, wconn, t.privKey, sconn.RemotePublicKey()) } -var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) - func (t *transport) CanDial(addr ma.Multiaddr) bool { var numHashes int ma.ForEach(addr, func(c ma.Component) bool { @@ -125,10 +122,13 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { for i := 0; i < numHashes; i++ { addr, _ = ma.SplitLast(addr) } - return dialMatcher.Matches(addr) + return webtransportMatcher.Matches(addr) } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + if !webtransportMatcher.Matches(laddr) { + return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) + } return newListener(laddr, t.tlsConf, t, t.noise) } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index a03771f9e8..bddd23689c 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -24,6 +24,14 @@ func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) { return id, key } +func randomMultihash(t *testing.T) string { + b := make([]byte, 16) + rand.Read(b) + s, err := multibase.Encode(multibase.Base32hex, b) + require.NoError(t, err) + return s +} + func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey) @@ -54,18 +62,10 @@ func TestTransport(t *testing.T) { } func TestCanDial(t *testing.T) { - randomHash := func(t *testing.T) string { - b := make([]byte, 16) - rand.Read(b) - s, err := multibase.Encode(multibase.Base32hex, b) - require.NoError(t, err) - return s - } - valid := []ma.Multiaddr{ - ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomHash(t)), - ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic/webtransport/certhash/" + randomHash(t)), - ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s/certhash/%s", randomHash(t), randomHash(t), randomHash(t))), + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), + ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), + ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s/certhash/%s", randomMultihash(t), randomMultihash(t), randomMultihash(t))), } invalid := []ma.Multiaddr{ @@ -85,3 +85,30 @@ func TestCanDial(t *testing.T) { require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr) } } + +func TestListenAddrValidity(t *testing.T) { + valid := []ma.Multiaddr{ + ma.StringCast("/ip6/::/udp/0/quic/webtransport/"), + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/"), + } + + invalid := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport + ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic + ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke? + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), + } + + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key) + require.NoError(t, err) + for _, addr := range valid { + ln, err := tr.Listen(addr) + require.NoErrorf(t, err, "expected to be able to listen on %s", addr) + ln.Close() + } + for _, addr := range invalid { + _, err := tr.Listen(addr) + require.Errorf(t, err, "expected to not be able to listen on %s", addr) + } +} From ee0e2d022f7990f29c7fc6700a0ba281c2561cf6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 19 Apr 2022 00:02:22 +0100 Subject: [PATCH 11/44] initialize a certiticate manager and pass it to the listeners --- p2p/transport/webtransport/listener.go | 31 +++++++------------- p2p/transport/webtransport/transport.go | 22 ++++++++------ p2p/transport/webtransport/transport_test.go | 26 ++++++++++++++++ 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 6ead26e01c..f29256eef5 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "github.com/lucas-clemente/quic-go/http3" "net" "net/http" "time" @@ -13,11 +12,10 @@ import ( noise "github.com/libp2p/go-libp2p-noise" + "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" - "github.com/multiformats/go-multibase" - "github.com/multiformats/go-multihash" ) var errClosed = errors.New("closed") @@ -26,11 +24,11 @@ const queueLen = 16 const handshakeTimeout = 10 * time.Second type listener struct { - transport tpt.Transport - noise *noise.Transport + transport tpt.Transport + noise *noise.Transport + certManager *certManager - server webtransport.Server - tlsConf *tls.Config + server webtransport.Server ctx context.Context ctxCancel context.CancelFunc @@ -45,7 +43,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(laddr ma.Multiaddr, tlsConf *tls.Config, transport tpt.Transport, noise *noise.Transport) (tpt.Listener, error) { +func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager) (tpt.Listener, error) { network, addr, err := manet.DialArgs(laddr) if err != nil { return nil, err @@ -65,15 +63,17 @@ func newListener(laddr ma.Multiaddr, tlsConf *tls.Config, transport tpt.Transpor ln := &listener{ transport: transport, noise: noise, + certManager: certManager, queue: make(chan *webtransport.Conn, queueLen), serverClosed: make(chan struct{}), addr: udpConn.LocalAddr(), - tlsConf: tlsConf, multiaddr: localMultiaddr, server: webtransport.Server{ H3: http3.Server{ Server: &http.Server{ - TLSConfig: tlsConf, + TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return certManager.GetConfig(), nil + }}, }, }, }, @@ -154,16 +154,7 @@ func (l *listener) Addr() net.Addr { } func (l *listener) Multiaddr() ma.Multiaddr { - certHash := certificateHash(l.tlsConf) - h, err := multihash.Encode(certHash[:], multihash.SHA2_256) - if err != nil { - panic(err) - } - certHashStr, err := multibase.Encode(multibase.Base58BTC, h) - if err != nil { - panic(err) - } - return l.multiaddr.Encapsulate(ma.StringCast("/certhash/" + certHashStr)) + return l.multiaddr.Encapsulate(l.certManager.AddrComponent()) } func (l *listener) Close() error { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 4e1efb6cac..ba47a1d64d 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "sync" "time" ic "github.com/libp2p/go-libp2p-core/crypto" @@ -29,8 +30,11 @@ type transport struct { privKey ic.PrivKey pid peer.ID - tlsConf *tls.Config - dialer webtransport.Dialer + dialer webtransport.Dialer + + listenOnce sync.Once + listenOnceErr error + certManager *certManager noise *noise.Transport } @@ -42,11 +46,6 @@ func New(key ic.PrivKey) (tpt.Transport, error) { if err != nil { return nil, err } - now := time.Now() - tlsConf, err := getTLSConf(now, now.Add(certValidity)) // TODO: only do this when initializing a listener - if err != nil { - return nil, err - } noise, err := noise.New(key) if err != nil { return nil, err @@ -54,7 +53,6 @@ func New(key ic.PrivKey) (tpt.Transport, error) { return &transport{ pid: id, privKey: key, - tlsConf: tlsConf, dialer: webtransport.Dialer{ TLSClientConf: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, }, @@ -129,7 +127,13 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if !webtransportMatcher.Matches(laddr) { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) } - return newListener(laddr, t.tlsConf, t, t.noise) + t.listenOnce.Do(func() { + t.certManager, t.listenOnceErr = newCertManager(certValidity) + }) + if t.listenOnceErr != nil { + return nil, t.listenOnceErr + } + return newListener(laddr, t, t.noise, t.certManager) } func (t *transport) Protocols() []int { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index bddd23689c..84a053be77 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -32,6 +32,17 @@ func randomMultihash(t *testing.T) string { return s } +func extractCertHashes(t *testing.T, addr ma.Multiaddr) []string { + var certHashesStr []string + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + certHashesStr = append(certHashesStr, c.Value()) + } + return true + }) + return certHashesStr +} + func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey) @@ -112,3 +123,18 @@ func TestListenAddrValidity(t *testing.T) { require.Errorf(t, err, "expected to not be able to listen on %s", addr) } } + +func TestListenerAddrs(t *testing.T) { + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key) + require.NoError(t, err) + + ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + hashes1 := extractCertHashes(t, ln1.Multiaddr()) + require.Len(t, hashes1, 1) + hashes2 := extractCertHashes(t, ln2.Multiaddr()) + require.Equal(t, hashes1, hashes2) +} From ae21c6c8a0a38ef715cb8c23edc1b4fdf692aaf1 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 19 Apr 2022 22:23:22 +0100 Subject: [PATCH 12/44] implement a Close method for the transport --- p2p/transport/webtransport/transport.go | 10 ++++++++++ p2p/transport/webtransport/transport_test.go | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index ba47a1d64d..294cc6ccfc 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "sync" "time" @@ -40,6 +41,7 @@ type transport struct { } var _ tpt.Transport = &transport{} +var _ io.Closer = &transport{} func New(key ic.PrivKey) (tpt.Transport, error) { id, err := peer.IDFromPrivateKey(key) @@ -143,3 +145,11 @@ func (t *transport) Protocols() []int { func (t *transport) Proxy() bool { return false } + +func (t *transport) Close() error { + t.listenOnce.Do(func() {}) + if t.certManager != nil { + return t.certManager.Close() + } + return nil +} diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 84a053be77..d50832144d 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -47,6 +47,7 @@ func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey) require.NoError(t, err) + defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) @@ -54,6 +55,8 @@ func TestTransport(t *testing.T) { _, clientKey := newIdentity(t) tr2, err := libp2pwebtransport.New(clientKey) require.NoError(t, err) + defer tr2.(io.Closer).Close() + conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) str, err := conn.OpenStream(context.Background()) @@ -89,6 +92,8 @@ func TestCanDial(t *testing.T) { _, key := newIdentity(t) tr, err := libp2pwebtransport.New(key) require.NoError(t, err) + defer tr.(io.Closer).Close() + for _, addr := range valid { require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr) } @@ -113,6 +118,8 @@ func TestListenAddrValidity(t *testing.T) { _, key := newIdentity(t) tr, err := libp2pwebtransport.New(key) require.NoError(t, err) + defer tr.(io.Closer).Close() + for _, addr := range valid { ln, err := tr.Listen(addr) require.NoErrorf(t, err, "expected to be able to listen on %s", addr) @@ -128,6 +135,7 @@ func TestListenerAddrs(t *testing.T) { _, key := newIdentity(t) tr, err := libp2pwebtransport.New(key) require.NoError(t, err) + defer tr.(io.Closer).Close() ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) From 323960e226ab16289a4e943849ea812e35da00bc Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 21 May 2022 21:14:53 +0200 Subject: [PATCH 13/44] chore: update webtransport-go --- p2p/transport/webtransport/listener.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index f29256eef5..dd32e181fa 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -70,11 +70,9 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans multiaddr: localMultiaddr, server: webtransport.Server{ H3: http3.Server{ - Server: &http.Server{ - TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return certManager.GetConfig(), nil - }}, - }, + TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return certManager.GetConfig(), nil + }}, }, }, } @@ -92,6 +90,8 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans } // TODO: handle queue overflow ln.queue <- c + // We need to block until we're done with this WebTransport session. + <-c.Context().Done() }) ln.server.H3.Handler = mux go func() { From 0af48f54863b78b1b45b7af80a76ac8a5a6dbf48 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 May 2022 11:42:39 +0200 Subject: [PATCH 14/44] run a Noise handshake on the first stream to verify peer identities --- p2p/transport/webtransport/cert_manager.go | 28 +- p2p/transport/webtransport/pb/Makefile | 11 + .../webtransport/pb/webtransport.pb.go | 315 ++++++++++++++++++ .../webtransport/pb/webtransport.proto | 5 + p2p/transport/webtransport/transport.go | 57 +++- p2p/transport/webtransport/transport_test.go | 44 ++- 6 files changed, 441 insertions(+), 19 deletions(-) create mode 100644 p2p/transport/webtransport/pb/Makefile create mode 100644 p2p/transport/webtransport/pb/webtransport.pb.go create mode 100644 p2p/transport/webtransport/pb/webtransport.proto diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index 83fc2878cf..f2a0afd04e 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -1,10 +1,11 @@ package libp2pwebtransport import ( + "bytes" "context" "crypto/sha256" "crypto/tls" - "crypto/x509" + "fmt" "sync" "time" @@ -16,6 +17,7 @@ import ( type certConfig struct { start, end time.Time tlsConf *tls.Config + sha256 [32]byte // cached from the tlsConf } func newCertConfig(start, end time.Time, conf *tls.Config) (*certConfig, error) { @@ -23,6 +25,7 @@ func newCertConfig(start, end time.Time, conf *tls.Config) (*certConfig, error) start: start, end: end, tlsConf: conf, + sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw), }, nil } @@ -128,13 +131,27 @@ func (m *certManager) AddrComponent() ma.Multiaddr { return m.addrComp } +func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error { + for _, h := range hashes { + if h.Code != multihash.SHA2_256 { + return fmt.Errorf("expected SHA256 hash, got %d", h.Code) + } + if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) && + (m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) && + (m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) { + return fmt.Errorf("found unexpected hash: %+x", h.Digest) + } + } + return nil +} + func (m *certManager) cacheAddrComponent() error { - addr, err := m.addrComponentForCert(m.currentConfig.tlsConf.Certificates[0].Leaf) + addr, err := m.addrComponentForCert(m.currentConfig.sha256[:]) if err != nil { return err } if m.nextConfig != nil { - comp, err := m.addrComponentForCert(m.nextConfig.tlsConf.Certificates[0].Leaf) + comp, err := m.addrComponentForCert(m.nextConfig.sha256[:]) if err != nil { return err } @@ -144,9 +161,8 @@ func (m *certManager) cacheAddrComponent() error { return nil } -func (m *certManager) addrComponentForCert(cert *x509.Certificate) (ma.Multiaddr, error) { - hash := sha256.Sum256(cert.Raw) - mh, err := multihash.Encode(hash[:], multihash.SHA2_256) +func (m *certManager) addrComponentForCert(hash []byte) (ma.Multiaddr, error) { + mh, err := multihash.Encode(hash, multihash.SHA2_256) if err != nil { return nil, err } diff --git a/p2p/transport/webtransport/pb/Makefile b/p2p/transport/webtransport/pb/Makefile new file mode 100644 index 0000000000..8af2dd8177 --- /dev/null +++ b/p2p/transport/webtransport/pb/Makefile @@ -0,0 +1,11 @@ +PB = $(wildcard *.proto) +GO = $(PB:.proto=.pb.go) + +all: $(GO) + +%.pb.go: %.proto + protoc --proto_path=$(PWD)/../..:. --gogofaster_out=. $< + +clean: + rm -f *.pb.go + rm -f *.go diff --git a/p2p/transport/webtransport/pb/webtransport.pb.go b/p2p/transport/webtransport/pb/webtransport.pb.go new file mode 100644 index 0000000000..810d0667c9 --- /dev/null +++ b/p2p/transport/webtransport/pb/webtransport.pb.go @@ -0,0 +1,315 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: webtransport.proto + +package webtransport + +import ( + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type WebTransport struct { + CertHashes [][]byte `protobuf:"bytes,1,rep,name=cert_hashes,json=certHashes" json:"cert_hashes,omitempty"` +} + +func (m *WebTransport) Reset() { *m = WebTransport{} } +func (m *WebTransport) String() string { return proto.CompactTextString(m) } +func (*WebTransport) ProtoMessage() {} +func (*WebTransport) Descriptor() ([]byte, []int) { + return fileDescriptor_db878920ab41a4f3, []int{0} +} +func (m *WebTransport) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *WebTransport) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_WebTransport.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *WebTransport) XXX_Merge(src proto.Message) { + xxx_messageInfo_WebTransport.Merge(m, src) +} +func (m *WebTransport) XXX_Size() int { + return m.Size() +} +func (m *WebTransport) XXX_DiscardUnknown() { + xxx_messageInfo_WebTransport.DiscardUnknown(m) +} + +var xxx_messageInfo_WebTransport proto.InternalMessageInfo + +func (m *WebTransport) GetCertHashes() [][]byte { + if m != nil { + return m.CertHashes + } + return nil +} + +func init() { + proto.RegisterType((*WebTransport)(nil), "WebTransport") +} + +func init() { proto.RegisterFile("webtransport.proto", fileDescriptor_db878920ab41a4f3) } + +var fileDescriptor_db878920ab41a4f3 = []byte{ + // 109 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x4f, 0x4d, 0x2a, + 0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x57, 0xd2, + 0xe7, 0xe2, 0x09, 0x4f, 0x4d, 0x0a, 0x81, 0x89, 0x0a, 0xc9, 0x73, 0x71, 0x27, 0xa7, 0x16, 0x95, + 0xc4, 0x67, 0x24, 0x16, 0x67, 0xa4, 0x16, 0x4b, 0x30, 0x2a, 0x30, 0x6b, 0xf0, 0x04, 0x71, 0x81, + 0x84, 0x3c, 0xc0, 0x22, 0x4e, 0x12, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, + 0x91, 0x1c, 0xe3, 0x84, 0xc7, 0x72, 0x0c, 0x17, 0x1e, 0xcb, 0x31, 0xdc, 0x78, 0x2c, 0xc7, 0x00, + 0x08, 0x00, 0x00, 0xff, 0xff, 0x50, 0x77, 0xe5, 0x52, 0x5f, 0x00, 0x00, 0x00, +} + +func (m *WebTransport) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *WebTransport) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *WebTransport) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.CertHashes) > 0 { + for iNdEx := len(m.CertHashes) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.CertHashes[iNdEx]) + copy(dAtA[i:], m.CertHashes[iNdEx]) + i = encodeVarintWebtransport(dAtA, i, uint64(len(m.CertHashes[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func encodeVarintWebtransport(dAtA []byte, offset int, v uint64) int { + offset -= sovWebtransport(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *WebTransport) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.CertHashes) > 0 { + for _, b := range m.CertHashes { + l = len(b) + n += 1 + l + sovWebtransport(uint64(l)) + } + } + return n +} + +func sovWebtransport(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozWebtransport(x uint64) (n int) { + return sovWebtransport(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *WebTransport) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowWebtransport + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: WebTransport: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: WebTransport: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CertHashes", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowWebtransport + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthWebtransport + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthWebtransport + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.CertHashes = append(m.CertHashes, make([]byte, postIndex-iNdEx)) + copy(m.CertHashes[len(m.CertHashes)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipWebtransport(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthWebtransport + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipWebtransport(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowWebtransport + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowWebtransport + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowWebtransport + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthWebtransport + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupWebtransport + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthWebtransport + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthWebtransport = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowWebtransport = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupWebtransport = fmt.Errorf("proto: unexpected end of group") +) diff --git a/p2p/transport/webtransport/pb/webtransport.proto b/p2p/transport/webtransport/pb/webtransport.proto new file mode 100644 index 0000000000..a9129ce6c6 --- /dev/null +++ b/p2p/transport/webtransport/pb/webtransport.proto @@ -0,0 +1,5 @@ +syntax = "proto2"; + +message WebTransport { + repeated bytes cert_hashes = 1; +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 294cc6ccfc..8b35d88b5f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -8,17 +8,21 @@ import ( "sync" "time" + pb "github.com/marten-seemann/go-libp2p-webtransport/pb" + ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" + noise "github.com/libp2p/go-libp2p-noise" - logging "github.com/ipfs/go-log/v2" - "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" + + logging "github.com/ipfs/go-log/v2" + "github.com/marten-seemann/webtransport-go" ) var log = logging.Logger("webtransport") @@ -48,18 +52,19 @@ func New(key ic.PrivKey) (tpt.Transport, error) { if err != nil { return nil, err } - noise, err := noise.New(key) - if err != nil { - return nil, err - } - return &transport{ + t := &transport{ pid: id, privKey: key, dialer: webtransport.Dialer{ TLSClientConf: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, }, - noise: noise, - }, nil + } + noise, err := noise.New(key, noise.WithEarlyDataHandler(t.checkEarlyData)) + if err != nil { + return nil, err + } + t.noise = noise + return t, nil } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { @@ -99,15 +104,43 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, err } - // TODO: use early data and verify the cert hash - _ = certHashes - sconn, err := t.noise.SecureOutbound(ctx, &webtransportStream{Stream: str, wconn: wconn}, p) + + // Now run a Noise handshake (using early data) and verify the cert hash. + msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))} + for _, certHash := range certHashes { + h, err := multihash.Encode(certHash.Digest, certHash.Code) + if err != nil { + return nil, fmt.Errorf("failed to encode certificate hash: %w", err) + } + msg.CertHashes = append(msg.CertHashes, h) + } + msgBytes, err := msg.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) + } + sconn, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wconn: wconn}, p, msgBytes) if err != nil { return nil, err } return newConn(t, wconn, t.privKey, sconn.RemotePublicKey()) } +func (t *transport) checkEarlyData(b []byte) error { + var msg pb.WebTransport + if err := msg.Unmarshal(b); err != nil { + return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) + } + hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) + for _, h := range msg.CertHashes { + dh, err := multihash.Decode(h) + if err != nil { + return fmt.Errorf("failed to decode hash: %w", err) + } + hashes = append(hashes, *dh) + } + return t.certManager.Verify(hashes) +} + func (t *transport) CanDial(addr ma.Multiaddr) bool { var numHashes int ma.ForEach(addr, func(c ma.Component) bool { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index d50832144d..a3fcb34b0b 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -3,16 +3,20 @@ package libp2pwebtransport_test import ( "context" "crypto/rand" + "crypto/sha256" "fmt" "io" "testing" + libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" + ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" - libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" ma "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" ) @@ -50,6 +54,7 @@ func TestTransport(t *testing.T) { defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) + defer ln.Close() go func() { _, clientKey := newIdentity(t) @@ -75,6 +80,43 @@ func TestTransport(t *testing.T) { require.Equal(t, "foobar", string(data)) } +func TestHashVerification(t *testing.T) { + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + done := make(chan struct{}) + go func() { + defer close(done) + _, err := ln.Accept() + require.Error(t, err) + }() + + // replace the certificate hash in the multiaddr with a fake hash + addr, _ := ma.SplitLast(ln.Multiaddr()) + h := sha256.Sum256([]byte("foobar")) + mh, err := multihash.Encode(h[:], multihash.SHA2_256) + require.NoError(t, err) + certStr, err := multibase.Encode(multibase.Base58BTC, mh) + require.NoError(t, err) + comp, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) + require.NoError(t, err) + addr = addr.Encapsulate(comp) + + _, clientKey := newIdentity(t) + tr2, err := libp2pwebtransport.New(clientKey) + require.NoError(t, err) + defer tr2.(io.Closer).Close() + + _, err = tr2.Dial(context.Background(), addr, serverID) + require.Error(t, err) + + require.NoError(t, ln.Close()) + <-done +} + func TestCanDial(t *testing.T) { valid := []ma.Multiaddr{ ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), From 52ae9b79b98986259a6efddc780ea7c3649b8c07 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 May 2022 12:15:43 +0200 Subject: [PATCH 15/44] simplify the listener.Accept logic --- p2p/transport/webtransport/listener.go | 49 +++++++++----------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index dd32e181fa..993354f965 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -38,7 +38,7 @@ type listener struct { addr net.Addr multiaddr ma.Multiaddr - queue chan *webtransport.Conn + queue chan tpt.CapableConn } var _ tpt.Listener = &listener{} @@ -64,7 +64,7 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans transport: transport, noise: noise, certManager: certManager, - queue: make(chan *webtransport.Conn, queueLen), + queue: make(chan tpt.CapableConn, queueLen), serverClosed: make(chan struct{}), addr: udpConn.LocalAddr(), multiaddr: localMultiaddr, @@ -88,8 +88,16 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans w.WriteHeader(500) return } - // TODO: handle queue overflow - ln.queue <- c + ctx, cancel := context.WithTimeout(ln.ctx, handshakeTimeout) + conn, err := ln.handshake(ctx, c) + if err != nil { + cancel() + log.Debugw("handshake failed", "error", err) + c.Close() + return + } + cancel() + ln.queue <- conn // We need to block until we're done with this WebTransport session. <-c.Context().Done() }) @@ -106,34 +114,11 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans } func (l *listener) Accept() (tpt.CapableConn, error) { - queue := make(chan tpt.CapableConn, queueLen) - for { - select { - case <-l.ctx.Done(): - return nil, errClosed - default: - } - - var c *webtransport.Conn - select { - case c = <-l.queue: - go func(c *webtransport.Conn) { - ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) - defer cancel() - conn, err := l.handshake(ctx, c) - if err != nil { - log.Debugw("handshake failed", "error", err) - c.Close() - return - } - // TODO: handle queue overflow - queue <- conn - }(c) - case conn := <-queue: - return conn, nil - case <-l.ctx.Done(): - return nil, errClosed - } + select { + case <-l.ctx.Done(): + return nil, errClosed + case c := <-l.queue: + return c, nil } } From 91e97305e35114575887aee717f80acea0b041d4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 May 2022 16:47:09 +0200 Subject: [PATCH 16/44] don't close the HTTP response body when dialing --- p2p/transport/webtransport/transport.go | 1 - 1 file changed, 1 deletion(-) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 8b35d88b5f..2596ba4dcd 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -96,7 +96,6 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, err } - defer rsp.Body.Close() if rsp.StatusCode < 200 || rsp.StatusCode > 299 { return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) } From d8209e3532707c1dc823efe6b7537651a63c979c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 May 2022 16:47:35 +0200 Subject: [PATCH 17/44] add more tests for failed handshakes --- p2p/transport/webtransport/transport_test.go | 29 +++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index a3fcb34b0b..58b4c0fa0f 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -94,24 +94,33 @@ func TestHashVerification(t *testing.T) { require.Error(t, err) }() - // replace the certificate hash in the multiaddr with a fake hash - addr, _ := ma.SplitLast(ln.Multiaddr()) + _, clientKey := newIdentity(t) + tr2, err := libp2pwebtransport.New(clientKey) + require.NoError(t, err) + defer tr2.(io.Closer).Close() + + // create a hash component using the SHA256 of foobar h := sha256.Sum256([]byte("foobar")) mh, err := multihash.Encode(h[:], multihash.SHA2_256) require.NoError(t, err) certStr, err := multibase.Encode(multibase.Base58BTC, mh) require.NoError(t, err) - comp, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) + foobarHash, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) require.NoError(t, err) - addr = addr.Encapsulate(comp) - _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey) - require.NoError(t, err) - defer tr2.(io.Closer).Close() + t.Run("fails using only a wrong hash", func(t *testing.T) { + // replace the certificate hash in the multiaddr with a fake hash + addr, _ := ma.SplitLast(ln.Multiaddr()) + addr = addr.Encapsulate(foobarHash) - _, err = tr2.Dial(context.Background(), addr, serverID) - require.Error(t, err) + _, err := tr2.Dial(context.Background(), addr, serverID) + require.Error(t, err) + }) + + t.Run("fails when adding a wrong hash", func(t *testing.T) { + _, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) + require.Error(t, err) + }) require.NoError(t, ln.Close()) <-done From 47be4122b7c9e8a5bf4c20abde21fec729cbb2fa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 22 May 2022 18:26:18 +0200 Subject: [PATCH 18/44] return local and remote multiaddrs from conn --- p2p/transport/webtransport/conn.go | 32 +++++++++++--------- p2p/transport/webtransport/transport_test.go | 11 +++++++ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index c1b5742a47..e2ab9cd50b 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -2,6 +2,7 @@ package libp2pwebtransport import ( "context" + "fmt" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" @@ -17,6 +18,7 @@ type conn struct { wconn *webtransport.Conn localPeer, remotePeer peer.ID + local, remote ma.Multiaddr privKey ic.PrivKey remotePubKey ic.PubKey } @@ -30,6 +32,14 @@ func newConn(tr tpt.Transport, wconn *webtransport.Conn, privKey ic.PrivKey, rem if err != nil { return nil, err } + local, err := toWebtransportMultiaddr(wconn.LocalAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting local addr: %w", err) + } + remote, err := toWebtransportMultiaddr(wconn.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting remote addr: %w", err) + } return &conn{ transport: tr, wconn: wconn, @@ -37,6 +47,8 @@ func newConn(tr tpt.Transport, wconn *webtransport.Conn, privKey ic.PrivKey, rem localPeer: localPeer, remotePeer: remotePeer, remotePubKey: remotePubKey, + local: local, + remote: remote, }, nil } @@ -60,20 +72,12 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { return &stream{str}, err } -func (c *conn) LocalPeer() peer.ID { return c.localPeer } -func (c *conn) LocalPrivateKey() ic.PrivKey { return c.privKey } -func (c *conn) RemotePeer() peer.ID { return c.remotePeer } -func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } - -func (c *conn) LocalMultiaddr() ma.Multiaddr { - // TODO implement me - panic("implement me") -} - -func (c *conn) RemoteMultiaddr() ma.Multiaddr { - // TODO implement me - panic("implement me") -} +func (c *conn) LocalPeer() peer.ID { return c.localPeer } +func (c *conn) LocalPrivateKey() ic.PrivKey { return c.privKey } +func (c *conn) RemotePeer() peer.ID { return c.remotePeer } +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } +func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.local } +func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remote } func (c *conn) Scope() network.ConnScope { // TODO implement me diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 58b4c0fa0f..4d640eb5d3 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "fmt" "io" + "strings" "testing" libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" @@ -14,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" @@ -56,6 +58,7 @@ func TestTransport(t *testing.T) { require.NoError(t, err) defer ln.Close() + addrChan := make(chan ma.Multiaddr) go func() { _, clientKey := newIdentity(t) tr2, err := libp2pwebtransport.New(clientKey) @@ -69,6 +72,13 @@ func TestTransport(t *testing.T) { _, err = str.Write([]byte("foobar")) require.NoError(t, err) require.NoError(t, str.Close()) + + // check RemoteMultiaddr + _, addr, err := manet.DialArgs(ln.Multiaddr()) + require.NoError(t, err) + port := strings.Split(addr, ":")[1] + require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic/webtransport", port)), conn.RemoteMultiaddr()) + addrChan <- conn.RemoteMultiaddr() }() conn, err := ln.Accept() @@ -78,6 +88,7 @@ func TestTransport(t *testing.T) { data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, "foobar", string(data)) + require.Equal(t, <-addrChan, conn.LocalMultiaddr()) } func TestHashVerification(t *testing.T) { From a0eec0f0e000d5578c04ea05d70d3cfd3d8a6b1a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 2 Jul 2022 11:09:28 +0200 Subject: [PATCH 19/44] update webtransport-go, rename webtransport.Conn to Session --- p2p/transport/webtransport/conn.go | 16 ++++++++-------- p2p/transport/webtransport/listener.go | 8 ++++---- p2p/transport/webtransport/stream.go | 6 +++--- p2p/transport/webtransport/transport.go | 12 +++++++----- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index e2ab9cd50b..90ae154b8c 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -15,7 +15,7 @@ import ( type conn struct { transport tpt.Transport - wconn *webtransport.Conn + wsess *webtransport.Session localPeer, remotePeer peer.ID local, remote ma.Multiaddr @@ -23,7 +23,7 @@ type conn struct { remotePubKey ic.PubKey } -func newConn(tr tpt.Transport, wconn *webtransport.Conn, privKey ic.PrivKey, remotePubKey ic.PubKey) (*conn, error) { +func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey) (*conn, error) { localPeer, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, err @@ -32,17 +32,17 @@ func newConn(tr tpt.Transport, wconn *webtransport.Conn, privKey ic.PrivKey, rem if err != nil { return nil, err } - local, err := toWebtransportMultiaddr(wconn.LocalAddr()) + local, err := toWebtransportMultiaddr(wsess.LocalAddr()) if err != nil { return nil, fmt.Errorf("error determiniting local addr: %w", err) } - remote, err := toWebtransportMultiaddr(wconn.RemoteAddr()) + remote, err := toWebtransportMultiaddr(wsess.RemoteAddr()) if err != nil { return nil, fmt.Errorf("error determiniting remote addr: %w", err) } return &conn{ transport: tr, - wconn: wconn, + wsess: wsess, privKey: privKey, localPeer: localPeer, remotePeer: remotePeer, @@ -55,7 +55,7 @@ func newConn(tr tpt.Transport, wconn *webtransport.Conn, privKey ic.PrivKey, rem var _ tpt.CapableConn = &conn{} func (c *conn) Close() error { - return c.wconn.Close() + return c.wsess.Close() } func (c *conn) IsClosed() bool { @@ -63,12 +63,12 @@ func (c *conn) IsClosed() bool { } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - str, err := c.wconn.OpenStreamSync(ctx) + str, err := c.wsess.OpenStreamSync(ctx) return &stream{str}, err } func (c *conn) AcceptStream() (network.MuxedStream, error) { - str, err := c.wconn.AcceptStream(context.Background()) + str, err := c.wsess.AcceptStream(context.Background()) return &stream{str}, err } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 993354f965..6789a74542 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -122,16 +122,16 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } } -func (l *listener) handshake(ctx context.Context, c *webtransport.Conn) (tpt.CapableConn, error) { - str, err := c.AcceptStream(ctx) +func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (tpt.CapableConn, error) { + str, err := sess.AcceptStream(ctx) if err != nil { return nil, err } - conn, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wconn: c}, "") + conn, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") if err != nil { return nil, err } - return newConn(l.transport, c, conn.LocalPrivateKey(), conn.RemotePublicKey()) + return newConn(l.transport, sess, conn.LocalPrivateKey(), conn.RemotePublicKey()) } func (l *listener) Addr() net.Addr { diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index fa34242b6c..6aa58cb8d8 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -15,17 +15,17 @@ const ( type webtransportStream struct { webtransport.Stream - wconn *webtransport.Conn + wsess *webtransport.Session } var _ net.Conn = &webtransportStream{} func (s *webtransportStream) LocalAddr() net.Addr { - return s.wconn.LocalAddr() + return s.wsess.LocalAddr() } func (s *webtransportStream) RemoteAddr() net.Addr { - return s.wconn.RemoteAddr() + return s.wsess.RemoteAddr() } type stream struct { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 2596ba4dcd..29eb784788 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -16,13 +16,13 @@ import ( noise "github.com/libp2p/go-libp2p-noise" + logging "github.com/ipfs/go-log/v2" + "github.com/lucas-clemente/quic-go/http3" + "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" - - logging "github.com/ipfs/go-log/v2" - "github.com/marten-seemann/webtransport-go" ) var log = logging.Logger("webtransport") @@ -56,7 +56,9 @@ func New(key ic.PrivKey) (tpt.Transport, error) { pid: id, privKey: key, dialer: webtransport.Dialer{ - TLSClientConf: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, + RoundTripper: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, + }, }, } noise, err := noise.New(key, noise.WithEarlyDataHandler(t.checkEarlyData)) @@ -117,7 +119,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - sconn, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wconn: wconn}, p, msgBytes) + sconn, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: wconn}, p, msgBytes) if err != nil { return nil, err } From d155f115144d34d99bc2ca1bf727ffd1ae5a06c5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 12:05:51 +0000 Subject: [PATCH 20/44] tighten multiaddr conversion logic --- p2p/transport/webtransport/multiaddr.go | 8 ++++++-- p2p/transport/webtransport/multiaddr_test.go | 21 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 p2p/transport/webtransport/multiaddr_test.go diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index 5560d90f68..9d79991938 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -1,6 +1,7 @@ package libp2pwebtransport import ( + "errors" "net" ma "github.com/multiformats/go-multiaddr" @@ -13,9 +14,12 @@ var webtransportMA = ma.StringCast("/quic/webtransport") var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { - udpMA, err := manet.FromNetAddr(na) + addr, err := manet.FromNetAddr(na) if err != nil { return nil, err } - return udpMA.Encapsulate(webtransportMA), nil + if _, err := addr.ValueForProtocol(ma.P_UDP); err != nil { + return nil, errors.New("not a UDP address") + } + return addr.Encapsulate(webtransportMA), nil } diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go new file mode 100644 index 0000000000..2098babbd0 --- /dev/null +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -0,0 +1,21 @@ +package libp2pwebtransport + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWebtransportMultiaddr(t *testing.T) { + t.Run("valid", func(t *testing.T) { + addr, err := toWebtransportMultiaddr(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) + require.NoError(t, err) + require.Equal(t, "/ip4/127.0.0.1/udp/1337/quic/webtransport", addr.String()) + }) + + t.Run("invalid", func(t *testing.T) { + _, err := toWebtransportMultiaddr(&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) + require.EqualError(t, err, "not a UDP address") + }) +} From 106cbc3810f725241ca27bab54ac2c1fbe6fcd7c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 14:30:41 +0000 Subject: [PATCH 21/44] return the listener's HTTP handler as soon as the handshake completes With the most recent webtransport-go, it's not necessary to block any more. --- p2p/transport/webtransport/listener.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 6789a74542..91767dac1c 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -97,9 +97,8 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans return } cancel() + // TODO: think about what happens when this channel fills up ln.queue <- conn - // We need to block until we're done with this WebTransport session. - <-c.Context().Done() }) ln.server.H3.Handler = mux go func() { From ed5a2f58da861540ee5c31de955bbf8e064e81d7 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 11:30:49 +0000 Subject: [PATCH 22/44] use the resource manager when dialing --- p2p/transport/webtransport/transport.go | 26 ++++++++++- p2p/transport/webtransport/transport_test.go | 45 ++++++++++++++++---- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 29eb784788..52b4e554a2 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -11,6 +11,7 @@ import ( pb "github.com/marten-seemann/go-libp2p-webtransport/pb" ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" @@ -37,6 +38,8 @@ type transport struct { dialer webtransport.Dialer + rcmgr network.ResourceManager + listenOnce sync.Once listenOnceErr error certManager *certManager @@ -47,7 +50,7 @@ type transport struct { var _ tpt.Transport = &transport{} var _ io.Closer = &transport{} -func New(key ic.PrivKey) (tpt.Transport, error) { +func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err @@ -55,6 +58,7 @@ func New(key ic.PrivKey) (tpt.Transport, error) { t := &transport{ pid: id, privKey: key, + rcmgr: rcmgr, dialer: webtransport.Dialer{ RoundTripper: &http3.RoundTripper{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, @@ -70,6 +74,26 @@ func New(key ic.PrivKey) (tpt.Transport, error) { } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) + return nil, err + } + if err := scope.SetPeer(p); err != nil { + log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) + scope.Done() + return nil, err + } + + conn, err := t.dial(ctx, raddr, p) + if err != nil { + scope.Done() + return nil, err + } + return conn, nil +} + +func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { _, addr, err := manet.DialArgs(raddr) if err != nil { return nil, err diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 4d640eb5d3..f27a4984e4 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -4,16 +4,20 @@ import ( "context" "crypto/rand" "crypto/sha256" + "errors" "fmt" "io" - "strings" + "net" "testing" libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" ic "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" + "github.com/golang/mock/gomock" + mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" @@ -51,7 +55,7 @@ func extractCertHashes(t *testing.T, addr ma.Multiaddr) []string { func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey) + tr, err := libp2pwebtransport.New(serverKey, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -61,7 +65,7 @@ func TestTransport(t *testing.T) { addrChan := make(chan ma.Multiaddr) go func() { _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey) + tr2, err := libp2pwebtransport.New(clientKey, network.NullResourceManager) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -76,7 +80,8 @@ func TestTransport(t *testing.T) { // check RemoteMultiaddr _, addr, err := manet.DialArgs(ln.Multiaddr()) require.NoError(t, err) - port := strings.Split(addr, ":")[1] + _, port, err := net.SplitHostPort(addr) + require.NoError(t, err) require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic/webtransport", port)), conn.RemoteMultiaddr()) addrChan <- conn.RemoteMultiaddr() }() @@ -93,7 +98,7 @@ func TestTransport(t *testing.T) { func TestHashVerification(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey) + tr, err := libp2pwebtransport.New(serverKey, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -106,7 +111,7 @@ func TestHashVerification(t *testing.T) { }() _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey) + tr2, err := libp2pwebtransport.New(clientKey, network.NullResourceManager) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -152,7 +157,7 @@ func TestCanDial(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key) + tr, err := libp2pwebtransport.New(key, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -178,7 +183,7 @@ func TestListenAddrValidity(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key) + tr, err := libp2pwebtransport.New(key, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -195,7 +200,7 @@ func TestListenAddrValidity(t *testing.T) { func TestListenerAddrs(t *testing.T) { _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key) + tr, err := libp2pwebtransport.New(key, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -208,3 +213,25 @@ func TestListenerAddrs(t *testing.T) { hashes2 := extractCertHashes(t, ln2.Multiaddr()) require.Equal(t, hashes1, hashes2) } + +func TestResourceManagerDialing(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + + addr := ma.StringCast("/ip4/9.8.7.6/udp/1234/quic/webtransport") + p := peer.ID("foobar") + + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key, rcmgr) + require.NoError(t, err) + defer tr.(io.Closer).Close() + + scope := mocknetwork.NewMockConnManagementScope(ctrl) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil) + scope.EXPECT().SetPeer(p).Return(errors.New("denied")) + scope.EXPECT().Done() + + _, err = tr.Dial(context.Background(), addr, p) + require.EqualError(t, err, "denied") +} From 468fd51967585a6da58fd63345082cf313c0a703 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 11:33:19 +0000 Subject: [PATCH 23/44] move the HTTP handler to a separate method on the listener --- p2p/transport/webtransport/listener.go | 40 ++++++++++++++------------ 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 91767dac1c..50af8879c1 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -81,25 +81,7 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, world!")) }) - mux.HandleFunc(webtransportHTTPEndpoint, func(w http.ResponseWriter, r *http.Request) { - // TODO: check ?type=multistream URL param - c, err := ln.server.Upgrade(w, r) - if err != nil { - w.WriteHeader(500) - return - } - ctx, cancel := context.WithTimeout(ln.ctx, handshakeTimeout) - conn, err := ln.handshake(ctx, c) - if err != nil { - cancel() - log.Debugw("handshake failed", "error", err) - c.Close() - return - } - cancel() - // TODO: think about what happens when this channel fills up - ln.queue <- conn - }) + mux.HandleFunc(webtransportHTTPEndpoint, ln.httpHandler) ln.server.H3.Handler = mux go func() { defer close(ln.serverClosed) @@ -112,6 +94,26 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans return ln, nil } +func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { + // TODO: check ?type=multistream URL param + c, err := l.server.Upgrade(w, r) + if err != nil { + w.WriteHeader(500) + return + } + ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) + conn, err := l.handshake(ctx, c) + if err != nil { + cancel() + log.Debugw("handshake failed", "error", err) + c.Close() + return + } + cancel() + // TODO: think about what happens when this channel fills up + l.queue <- conn +} + func (l *listener) Accept() (tpt.CapableConn, error) { select { case <-l.ctx.Done(): From 2110e680dd0f8350fa5563ceacb67acb8ac7f5b6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 12:11:05 +0000 Subject: [PATCH 24/44] add a function to convert a IP:port string to a WebTransport multiaddr --- p2p/transport/webtransport/multiaddr.go | 17 +++++++++++++++++ p2p/transport/webtransport/multiaddr_test.go | 20 ++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index 9d79991938..a61b1b8962 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -3,6 +3,7 @@ package libp2pwebtransport import ( "errors" "net" + "strconv" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -23,3 +24,19 @@ func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { } return addr.Encapsulate(webtransportMA), nil } + +func stringToWebtransportMultiaddr(str string) (ma.Multiaddr, error) { + host, portStr, err := net.SplitHostPort(str) + if err != nil { + return nil, err + } + port, err := strconv.ParseInt(portStr, 10, 32) + if err != nil { + return nil, err + } + ip := net.ParseIP(host) + if ip == nil { + return nil, errors.New("failed to parse IP") + } + return toWebtransportMultiaddr(&net.UDPAddr{IP: ip, Port: int(port)}) +} diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go index 2098babbd0..5fc0432a33 100644 --- a/p2p/transport/webtransport/multiaddr_test.go +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -19,3 +19,23 @@ func TestWebtransportMultiaddr(t *testing.T) { require.EqualError(t, err, "not a UDP address") }) } + +func TestWebtransportMultiaddrFromString(t *testing.T) { + t.Run("valid", func(t *testing.T) { + addr, err := stringToWebtransportMultiaddr("1.2.3.4:60042") + require.NoError(t, err) + require.Equal(t, "/ip4/1.2.3.4/udp/60042/quic/webtransport", addr.String()) + }) + + t.Run("invalid", func(t *testing.T) { + for _, addr := range [...]string{ + "1.2.3.4", // missing port + "1.2.3.4:123456", // invalid port + ":1234", // missing IP + "foobar", + } { + _, err := stringToWebtransportMultiaddr(addr) + require.Error(t, err) + } + }) +} From f55a4ffc9a6e7bae7be6d4ddfe5d400c0d8ca504 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 15:19:45 +0000 Subject: [PATCH 25/44] use the resource manager when accepting connections --- p2p/transport/webtransport/listener.go | 32 ++++++++++- p2p/transport/webtransport/transport.go | 2 +- p2p/transport/webtransport/transport_test.go | 56 ++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 50af8879c1..ef21c23ffb 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + "github.com/libp2p/go-libp2p-core/network" tpt "github.com/libp2p/go-libp2p-core/transport" noise "github.com/libp2p/go-libp2p-noise" @@ -27,6 +28,7 @@ type listener struct { transport tpt.Transport noise *noise.Transport certManager *certManager + rcmgr network.ResourceManager server webtransport.Server @@ -43,7 +45,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager) (tpt.Listener, error) { +func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, rcmgr network.ResourceManager) (tpt.Listener, error) { network, addr, err := manet.DialArgs(laddr) if err != nil { return nil, err @@ -64,6 +66,7 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans transport: transport, noise: noise, certManager: certManager, + rcmgr: rcmgr, queue: make(chan tpt.CapableConn, queueLen), serverClosed: make(chan struct{}), addr: udpConn.LocalAddr(), @@ -95,10 +98,28 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans } func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { + remoteMultiaddr, err := stringToWebtransportMultiaddr(r.RemoteAddr) + if err != nil { + // This should never happen. + log.Errorw("converting remote address failed", "remote", r.RemoteAddr, "error", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr) + if err != nil { + log.Debugw("resource manager blocked incoming connection", "addr", r.RemoteAddr, "error", err) + w.WriteHeader(http.StatusServiceUnavailable) + return + } + // TODO: check ?type=multistream URL param c, err := l.server.Upgrade(w, r) if err != nil { + log.Debugw("upgrade failed", "error", err) + // TODO: think about the status code to use here w.WriteHeader(500) + connScope.Done() return } ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) @@ -107,9 +128,18 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { cancel() log.Debugw("handshake failed", "error", err) c.Close() + connScope.Done() return } cancel() + + if err := connScope.SetPeer(conn.RemotePeer()); err != nil { + log.Debugw("resource manager blocked incoming connection for peer", "peer", conn.RemotePeer(), "addr", r.RemoteAddr, "error", err) + conn.Close() + connScope.Done() + return + } + // TODO: think about what happens when this channel fills up l.queue <- conn } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 52b4e554a2..0d825076b2 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -193,7 +193,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if t.listenOnceErr != nil { return nil, t.listenOnceErr } - return newListener(laddr, t, t.noise, t.certManager) + return newListener(laddr, t, t.noise, t.certManager, t.rcmgr) } func (t *transport) Protocols() []int { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index f27a4984e4..8f8b70fb8e 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -235,3 +235,59 @@ func TestResourceManagerDialing(t *testing.T) { _, err = tr.Dial(context.Background(), addr, p) require.EqualError(t, err, "denied") } + +func TestResourceManagerListening(t *testing.T) { + clientID, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + t.Run("blocking the connection", func(t *testing.T) { + serverID, key := newIdentity(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + tr, err := libp2pwebtransport.New(key, rcmgr) + require.NoError(t, err) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(_ network.Direction, _ bool, addr ma.Multiaddr) (network.ConnManagementScope, error) { + _, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT) + require.NoError(t, err, "expected a WebTransport multiaddr") + _, addrStr, err := manet.DialArgs(addr) + require.NoError(t, err) + host, _, err := net.SplitHostPort(addrStr) + require.NoError(t, err) + require.Equal(t, "127.0.0.1", host) + return nil, errors.New("denied") + }) + + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.EqualError(t, err, "received status 503") + }) + + t.Run("blocking the peer", func(t *testing.T) { + serverID, key := newIdentity(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + tr, err := libp2pwebtransport.New(key, rcmgr) + require.NoError(t, err) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + scope := mocknetwork.NewMockConnManagementScope(ctrl) + rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil) + scope.EXPECT().SetPeer(clientID).Return(errors.New("denied")) + scope.EXPECT().Done() + + // The handshake will complete, but the server will immediately close the connection. + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + _, err = conn.AcceptStream() + require.Error(t, err) + }) +} From 3b13c8342d30ede876dac413c02b918599cb180e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 16:24:24 +0000 Subject: [PATCH 26/44] move extraction of certificate hashes to a separate function --- p2p/transport/webtransport/multiaddr.go | 26 +++++++++++ p2p/transport/webtransport/multiaddr_test.go | 47 ++++++++++++++++++++ p2p/transport/webtransport/transport.go | 22 ++------- p2p/transport/webtransport/transport_test.go | 6 +-- 4 files changed, 79 insertions(+), 22 deletions(-) diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index a61b1b8962..7b6df2ba25 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -2,12 +2,15 @@ package libp2pwebtransport import ( "errors" + "fmt" "net" "strconv" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" ) var webtransportMA = ma.StringCast("/quic/webtransport") @@ -40,3 +43,26 @@ func stringToWebtransportMultiaddr(str string) (ma.Multiaddr, error) { } return toWebtransportMultiaddr(&net.UDPAddr{IP: ip, Port: int(port)}) } + +func extractCertHashes(addr ma.Multiaddr) ([]multihash.DecodedMultihash, error) { + certHashesStr := make([]string, 0, 2) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + certHashesStr = append(certHashesStr, c.Value()) + } + return true + }) + certHashes := make([]multihash.DecodedMultihash, 0, len(certHashesStr)) + for _, s := range certHashesStr { + _, ch, err := multibase.Decode(s) + if err != nil { + return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err) + } + dh, err := multihash.Decode(ch) + if err != nil { + return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err) + } + certHashes = append(certHashes, *dh) + } + return certHashes, nil +} diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go index 5fc0432a33..7c95b00d01 100644 --- a/p2p/transport/webtransport/multiaddr_test.go +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -1,9 +1,13 @@ package libp2pwebtransport import ( + "fmt" "net" "testing" + ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" "github.com/stretchr/testify/require" ) @@ -39,3 +43,46 @@ func TestWebtransportMultiaddrFromString(t *testing.T) { } }) } + +func encodeCertHash(t *testing.T, b []byte, mh uint64, mb multibase.Encoding) string { + t.Helper() + h, err := multihash.Encode(b, mh) + require.NoError(t, err) + str, err := multibase.Encode(mb, h) + require.NoError(t, err) + return str +} + +func TestExtractCertHashes(t *testing.T) { + fooHash := encodeCertHash(t, []byte("foo"), multihash.SHA2_256, multibase.Base58BTC) + barHash := encodeCertHash(t, []byte("bar"), multihash.BLAKE2B_MAX, multibase.Base32) + + // valid cases + for _, tc := range [...]struct { + addr string + hashes []string + }{ + {addr: "/ip4/127.0.0.1/udp/1234/quic/webtransport"}, + {addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", fooHash), hashes: []string{"foo"}}, + {addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s", fooHash, barHash), hashes: []string{"foo", "bar"}}, + } { + ch, err := extractCertHashes(ma.StringCast(tc.addr)) + require.NoError(t, err) + require.Len(t, ch, len(tc.hashes)) + for i, h := range tc.hashes { + require.Equal(t, h, string(ch[i].Digest)) + } + } + + // invalid cases + for _, tc := range [...]struct { + addr string + err string + }{ + {addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", fooHash[:len(fooHash)-1]), err: "failed to multihash-decode certificate hash"}, + } { + _, err := extractCertHashes(ma.StringCast(tc.addr)) + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + } +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 0d825076b2..2d2ffcfb6f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -22,7 +22,6 @@ import ( "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" - "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" ) @@ -99,24 +98,9 @@ func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) - certHashesStr := make([]string, 0, 2) - ma.ForEach(raddr, func(c ma.Component) bool { - if c.Protocol().Code == ma.P_CERTHASH { - certHashesStr = append(certHashesStr, c.Value()) - } - return true - }) - var certHashes []multihash.DecodedMultihash - for _, s := range certHashesStr { - _, ch, err := multibase.Decode(s) - if err != nil { - return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err) - } - dh, err := multihash.Decode(ch) - if err != nil { - return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err) - } - certHashes = append(certHashes, *dh) + certHashes, err := extractCertHashes(raddr) + if err != nil { + return nil, err } rsp, wconn, err := t.dialer.Dial(ctx, url, nil) if err != nil { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 8f8b70fb8e..115c7f1030 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -42,7 +42,7 @@ func randomMultihash(t *testing.T) string { return s } -func extractCertHashes(t *testing.T, addr ma.Multiaddr) []string { +func extractCertHashes(addr ma.Multiaddr) []string { var certHashesStr []string ma.ForEach(addr, func(c ma.Component) bool { if c.Protocol().Code == ma.P_CERTHASH { @@ -208,9 +208,9 @@ func TestListenerAddrs(t *testing.T) { require.NoError(t, err) ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) - hashes1 := extractCertHashes(t, ln1.Multiaddr()) + hashes1 := extractCertHashes(ln1.Multiaddr()) require.Len(t, hashes1, 1) - hashes2 := extractCertHashes(t, ln2.Multiaddr()) + hashes2 := extractCertHashes(ln2.Multiaddr()) require.Equal(t, hashes1, hashes2) } From c7149b3be4b74a160c52271ccbda241d19e6219d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 18:49:40 +0000 Subject: [PATCH 27/44] pass the connection scope to the connection (#11) --- p2p/transport/webtransport/conn.go | 15 +++---- p2p/transport/webtransport/listener.go | 29 +++++++------ p2p/transport/webtransport/transport.go | 45 +++++++++++++------- p2p/transport/webtransport/transport_test.go | 15 ++++++- 4 files changed, 63 insertions(+), 41 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 90ae154b8c..3478d39db1 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -21,9 +21,10 @@ type conn struct { local, remote ma.Multiaddr privKey ic.PrivKey remotePubKey ic.PubKey + scope network.ConnScope } -func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey) (*conn, error) { +func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) { localPeer, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, err @@ -49,6 +50,7 @@ func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey: remotePubKey, local: local, remote: remote, + scope: scope, }, nil } @@ -78,12 +80,5 @@ func (c *conn) RemotePeer() peer.ID { return c.remotePeer } func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.local } func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remote } - -func (c *conn) Scope() network.ConnScope { - // TODO implement me - panic("implement me") -} - -func (c *conn) Transport() tpt.Transport { - return c.transport -} +func (c *conn) Scope() network.ConnScope { return c.scope } +func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ef21c23ffb..739c230ca6 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -114,7 +114,7 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { } // TODO: check ?type=multistream URL param - c, err := l.server.Upgrade(w, r) + sess, err := l.server.Upgrade(w, r) if err != nil { log.Debugw("upgrade failed", "error", err) // TODO: think about the status code to use here @@ -123,25 +123,32 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { return } ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) - conn, err := l.handshake(ctx, c) + sconn, err := l.handshake(ctx, sess) if err != nil { cancel() log.Debugw("handshake failed", "error", err) - c.Close() + sess.Close() connScope.Done() return } cancel() - if err := connScope.SetPeer(conn.RemotePeer()); err != nil { - log.Debugw("resource manager blocked incoming connection for peer", "peer", conn.RemotePeer(), "addr", r.RemoteAddr, "error", err) - conn.Close() + if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { + log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) + sess.Close() + connScope.Done() + return + } + + c, err := newConn(l.transport, sess, sconn.LocalPrivateKey(), sconn.RemotePublicKey(), connScope) + if err != nil { + sess.Close() connScope.Done() return } // TODO: think about what happens when this channel fills up - l.queue <- conn + l.queue <- c } func (l *listener) Accept() (tpt.CapableConn, error) { @@ -153,16 +160,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } } -func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (tpt.CapableConn, error) { +func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (network.ConnSecurity, error) { str, err := sess.AcceptStream(ctx) if err != nil { return nil, err } - conn, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") - if err != nil { - return nil, err - } - return newConn(l.transport, sess, conn.LocalPrivateKey(), conn.RemotePublicKey()) + return l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") } func (l *listener) Addr() net.Addr { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 2d2ffcfb6f..640221a9b1 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + manet "github.com/multiformats/go-multiaddr/net" "io" "sync" "time" @@ -21,7 +22,6 @@ import ( "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multihash" ) @@ -73,6 +73,15 @@ func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) { } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + _, addr, err := manet.DialArgs(raddr) + if err != nil { + return nil, err + } + certHashes, err := extractCertHashes(raddr) + if err != nil { + return nil, err + } + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) if err != nil { log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) @@ -84,32 +93,40 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } - conn, err := t.dial(ctx, raddr, p) + sess, err := t.dial(ctx, addr) if err != nil { scope.Done() return nil, err } - return conn, nil -} - -func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { - _, addr, err := manet.DialArgs(raddr) + sconn, err := t.upgrade(ctx, sess, p, certHashes) if err != nil { + sess.Close() + scope.Done() return nil, err } - url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) - certHashes, err := extractCertHashes(raddr) + c, err := newConn(t, sess, t.privKey, sconn.RemotePublicKey(), scope) if err != nil { + sess.Close() + scope.Done() return nil, err } - rsp, wconn, err := t.dialer.Dial(ctx, url, nil) + return c, nil +} + +func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Session, error) { + url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) + rsp, sess, err := t.dialer.Dial(ctx, url, nil) if err != nil { return nil, err } if rsp.StatusCode < 200 || rsp.StatusCode > 299 { return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) } - str, err := wconn.OpenStreamSync(ctx) + return sess, err +} + +func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (network.ConnSecurity, error) { + str, err := sess.OpenStreamSync(ctx) if err != nil { return nil, err } @@ -127,11 +144,7 @@ func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - sconn, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: wconn}, p, msgBytes) - if err != nil { - return nil, err - } - return newConn(t, wconn, t.privKey, sconn.RemotePublicKey()) + return t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes) } func (t *transport) checkEarlyData(b []byte) error { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 115c7f1030..7ddcff9ffa 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "testing" + "time" libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" @@ -287,7 +288,17 @@ func TestResourceManagerListening(t *testing.T) { // The handshake will complete, but the server will immediately close the connection. conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) - _, err = conn.AcceptStream() - require.Error(t, err) + defer conn.Close() + done := make(chan struct{}) + go func() { + defer close(done) + _, err = conn.AcceptStream() + require.Error(t, err) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } }) } From 69341f6e33172e8be86bcc5a7f3b6cd144a0710a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 21:44:22 +0000 Subject: [PATCH 28/44] properly implement conn.IsClosed (#13) --- p2p/transport/webtransport/conn.go | 24 ++++++++------------ p2p/transport/webtransport/transport_test.go | 3 +++ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 3478d39db1..2efc263fbf 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -15,7 +15,7 @@ import ( type conn struct { transport tpt.Transport - wsess *webtransport.Session + session *webtransport.Session localPeer, remotePeer peer.ID local, remote ma.Multiaddr @@ -24,7 +24,7 @@ type conn struct { scope network.ConnScope } -func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) { +func newConn(tr tpt.Transport, sess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) { localPeer, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, err @@ -33,17 +33,17 @@ func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, if err != nil { return nil, err } - local, err := toWebtransportMultiaddr(wsess.LocalAddr()) + local, err := toWebtransportMultiaddr(sess.LocalAddr()) if err != nil { return nil, fmt.Errorf("error determiniting local addr: %w", err) } - remote, err := toWebtransportMultiaddr(wsess.RemoteAddr()) + remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) if err != nil { return nil, fmt.Errorf("error determiniting remote addr: %w", err) } return &conn{ transport: tr, - wsess: wsess, + session: sess, privKey: privKey, localPeer: localPeer, remotePeer: remotePeer, @@ -56,24 +56,18 @@ func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, var _ tpt.CapableConn = &conn{} -func (c *conn) Close() error { - return c.wsess.Close() -} - -func (c *conn) IsClosed() bool { - panic("implement me") -} - func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - str, err := c.wsess.OpenStreamSync(ctx) + str, err := c.session.OpenStreamSync(ctx) return &stream{str}, err } func (c *conn) AcceptStream() (network.MuxedStream, error) { - str, err := c.wsess.AcceptStream(context.Background()) + str, err := c.session.AcceptStream(context.Background()) return &stream{str}, err } +func (c *conn) Close() error { return c.session.Close() } +func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } func (c *conn) LocalPeer() peer.ID { return c.localPeer } func (c *conn) LocalPrivateKey() ic.PrivKey { return c.privKey } func (c *conn) RemotePeer() peer.ID { return c.remotePeer } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 7ddcff9ffa..397664c5b6 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -89,12 +89,15 @@ func TestTransport(t *testing.T) { conn, err := ln.Accept() require.NoError(t, err) + require.False(t, conn.IsClosed()) str, err := conn.AcceptStream() require.NoError(t, err) data, err := io.ReadAll(str) require.NoError(t, err) require.Equal(t, "foobar", string(data)) require.Equal(t, <-addrChan, conn.LocalMultiaddr()) + require.NoError(t, conn.Close()) + require.True(t, conn.IsClosed()) } func TestHashVerification(t *testing.T) { From ecc1eff49d77ac8e025662f9aed6d227f4d23e1d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 21:48:20 +0000 Subject: [PATCH 29/44] refactor the conn constructor --- p2p/transport/webtransport/conn.go | 64 ++++++------------------- p2p/transport/webtransport/listener.go | 32 +++++++++---- p2p/transport/webtransport/transport.go | 47 ++++++++++++++---- 3 files changed, 74 insertions(+), 69 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 2efc263fbf..a8f73b079d 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -2,60 +2,32 @@ package libp2pwebtransport import ( "context" - "fmt" - - ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" "github.com/marten-seemann/webtransport-go" - ma "github.com/multiformats/go-multiaddr" ) type conn struct { + connSecurityMultiaddrs + transport tpt.Transport session *webtransport.Session - localPeer, remotePeer peer.ID - local, remote ma.Multiaddr - privKey ic.PrivKey - remotePubKey ic.PubKey - scope network.ConnScope + scope network.ConnScope } -func newConn(tr tpt.Transport, sess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) { - localPeer, err := peer.IDFromPrivateKey(privKey) - if err != nil { - return nil, err - } - remotePeer, err := peer.IDFromPublicKey(remotePubKey) - if err != nil { - return nil, err - } - local, err := toWebtransportMultiaddr(sess.LocalAddr()) - if err != nil { - return nil, fmt.Errorf("error determiniting local addr: %w", err) - } - remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) - if err != nil { - return nil, fmt.Errorf("error determiniting remote addr: %w", err) - } +var _ tpt.CapableConn = &conn{} + +func newConn(tr tpt.Transport, sess *webtransport.Session, sconn connSecurityMultiaddrs, scope network.ConnScope) *conn { return &conn{ - transport: tr, - session: sess, - privKey: privKey, - localPeer: localPeer, - remotePeer: remotePeer, - remotePubKey: remotePubKey, - local: local, - remote: remote, - scope: scope, - }, nil + connSecurityMultiaddrs: sconn, + transport: tr, + session: sess, + scope: scope, + } } -var _ tpt.CapableConn = &conn{} - func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { str, err := c.session.OpenStreamSync(ctx) return &stream{str}, err @@ -66,13 +38,7 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { return &stream{str}, err } -func (c *conn) Close() error { return c.session.Close() } -func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } -func (c *conn) LocalPeer() peer.ID { return c.localPeer } -func (c *conn) LocalPrivateKey() ic.PrivKey { return c.privKey } -func (c *conn) RemotePeer() peer.ID { return c.remotePeer } -func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } -func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.local } -func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remote } -func (c *conn) Scope() network.ConnScope { return c.scope } -func (c *conn) Transport() tpt.Transport { return c.transport } +func (c *conn) Close() error { return c.session.Close() } +func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } +func (c *conn) Scope() network.ConnScope { return c.scope } +func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 739c230ca6..73794f2f3c 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "net" "net/http" "time" @@ -140,15 +141,8 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { return } - c, err := newConn(l.transport, sess, sconn.LocalPrivateKey(), sconn.RemotePublicKey(), connScope) - if err != nil { - sess.Close() - connScope.Done() - return - } - // TODO: think about what happens when this channel fills up - l.queue <- c + l.queue <- newConn(l.transport, sess, sconn, connScope) } func (l *listener) Accept() (tpt.CapableConn, error) { @@ -160,12 +154,30 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } } -func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (network.ConnSecurity, error) { +func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (connSecurityMultiaddrs, error) { + local, err := toWebtransportMultiaddr(sess.LocalAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting local addr: %w", err) + } + remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting remote addr: %w", err) + } + str, err := sess.AcceptStream(ctx) if err != nil { return nil, err } - return l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") + c, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") + if err != nil { + return nil, err + } + + return &connSecurityMultiaddrsImpl{ + ConnSecurity: c, + local: local, + remote: remote, + }, nil } func (l *listener) Addr() net.Addr { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 640221a9b1..9fae27f27a 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "fmt" - manet "github.com/multiformats/go-multiaddr/net" "io" "sync" "time" @@ -22,6 +21,7 @@ import ( "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multihash" ) @@ -31,6 +31,21 @@ const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" const certValidity = 14 * 24 * time.Hour +type connSecurityMultiaddrs interface { + network.ConnMultiaddrs + network.ConnSecurity +} + +type connSecurityMultiaddrsImpl struct { + network.ConnSecurity + local, remote ma.Multiaddr +} + +var _ connSecurityMultiaddrs = &connSecurityMultiaddrsImpl{} + +func (c *connSecurityMultiaddrsImpl) LocalMultiaddr() ma.Multiaddr { return c.local } +func (c *connSecurityMultiaddrsImpl) RemoteMultiaddr() ma.Multiaddr { return c.remote } + type transport struct { privKey ic.PrivKey pid peer.ID @@ -104,13 +119,8 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp scope.Done() return nil, err } - c, err := newConn(t, sess, t.privKey, sconn.RemotePublicKey(), scope) - if err != nil { - sess.Close() - scope.Done() - return nil, err - } - return c, nil + + return newConn(t, sess, sconn, scope), nil } func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Session, error) { @@ -125,7 +135,16 @@ func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Sessio return sess, err } -func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (network.ConnSecurity, error) { +func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (connSecurityMultiaddrs, error) { + local, err := toWebtransportMultiaddr(sess.LocalAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting local addr: %w", err) + } + remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting remote addr: %w", err) + } + str, err := sess.OpenStreamSync(ctx) if err != nil { return nil, err @@ -144,7 +163,15 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - return t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes) + c, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes) + if err != nil { + return nil, err + } + return &connSecurityMultiaddrsImpl{ + ConnSecurity: c, + local: local, + remote: remote, + }, nil } func (t *transport) checkEarlyData(b []byte) error { From 2c6fc83df04675e5cef93d6461d37bf898bcd6c8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 22:00:49 +0000 Subject: [PATCH 30/44] implement connection gating for dials --- .../mock_connection_gater_test.go | 109 ++++++++++++++++++ p2p/transport/webtransport/transport.go | 10 +- p2p/transport/webtransport/transport_test.go | 57 +++++++-- 3 files changed, 164 insertions(+), 12 deletions(-) create mode 100644 p2p/transport/webtransport/mock_connection_gater_test.go diff --git a/p2p/transport/webtransport/mock_connection_gater_test.go b/p2p/transport/webtransport/mock_connection_gater_test.go new file mode 100644 index 0000000000..071ed7494f --- /dev/null +++ b/p2p/transport/webtransport/mock_connection_gater_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p-core/connmgr (interfaces: ConnectionGater) + +// Package libp2pwebtransport_test is a generated GoMock package. +package libp2pwebtransport_test + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + control "github.com/libp2p/go-libp2p-core/control" + network "github.com/libp2p/go-libp2p-core/network" + peer "github.com/libp2p/go-libp2p-core/peer" + multiaddr "github.com/multiformats/go-multiaddr" +) + +// MockConnectionGater is a mock of ConnectionGater interface. +type MockConnectionGater struct { + ctrl *gomock.Controller + recorder *MockConnectionGaterMockRecorder +} + +// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater. +type MockConnectionGaterMockRecorder struct { + mock *MockConnectionGater +} + +// NewMockConnectionGater creates a new mock instance. +func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater { + mock := &MockConnectionGater{ctrl: ctrl} + mock.recorder = &MockConnectionGaterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder { + return m.recorder +} + +// InterceptAccept mocks base method. +func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAccept", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAccept indicates an expected call of InterceptAccept. +func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0) +} + +// InterceptAddrDial mocks base method. +func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAddrDial indicates an expected call of InterceptAddrDial. +func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1) +} + +// InterceptPeerDial mocks base method. +func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptPeerDial", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptPeerDial indicates an expected call of InterceptPeerDial. +func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0) +} + +// InterceptSecured mocks base method. +func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptSecured indicates an expected call of InterceptSecured. +func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2) +} + +// InterceptUpgraded mocks base method. +func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptUpgraded", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(control.DisconnectReason) + return ret0, ret1 +} + +// InterceptUpgraded indicates an expected call of InterceptUpgraded. +func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0) +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 9fae27f27a..834412df36 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -10,6 +10,7 @@ import ( pb "github.com/marten-seemann/go-libp2p-webtransport/pb" + "github.com/libp2p/go-libp2p-core/connmgr" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -53,6 +54,7 @@ type transport struct { dialer webtransport.Dialer rcmgr network.ResourceManager + gater connmgr.ConnectionGater listenOnce sync.Once listenOnceErr error @@ -64,7 +66,7 @@ type transport struct { var _ tpt.Transport = &transport{} var _ io.Closer = &transport{} -func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) { +func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err @@ -73,6 +75,7 @@ func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) { pid: id, privKey: key, rcmgr: rcmgr, + gater: gater, dialer: webtransport.Dialer{ RoundTripper: &http3.RoundTripper{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, @@ -119,6 +122,11 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp scope.Done() return nil, err } + if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, sconn) { + // TODO: can we close with a specific error here? + sess.Close() + return nil, fmt.Errorf("secured connection gated") + } return newConn(t, sess, sconn, scope), nil } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 397664c5b6..ff73e0f9d8 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -56,7 +56,7 @@ func extractCertHashes(addr ma.Multiaddr) []string { func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -66,7 +66,7 @@ func TestTransport(t *testing.T) { addrChan := make(chan ma.Multiaddr) go func() { _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey, network.NullResourceManager) + tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -102,7 +102,7 @@ func TestTransport(t *testing.T) { func TestHashVerification(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, network.NullResourceManager) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) @@ -115,7 +115,7 @@ func TestHashVerification(t *testing.T) { }() _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey, network.NullResourceManager) + tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -161,7 +161,7 @@ func TestCanDial(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, network.NullResourceManager) + tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -187,7 +187,7 @@ func TestListenAddrValidity(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, network.NullResourceManager) + tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -204,7 +204,7 @@ func TestListenAddrValidity(t *testing.T) { func TestListenerAddrs(t *testing.T) { _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, network.NullResourceManager) + tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -227,7 +227,7 @@ func TestResourceManagerDialing(t *testing.T) { p := peer.ID("foobar") _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, rcmgr) + tr, err := libp2pwebtransport.New(key, nil, rcmgr) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -242,7 +242,7 @@ func TestResourceManagerDialing(t *testing.T) { func TestResourceManagerListening(t *testing.T) { clientID, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, network.NullResourceManager) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -251,7 +251,7 @@ func TestResourceManagerListening(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tr, err := libp2pwebtransport.New(key, rcmgr) + tr, err := libp2pwebtransport.New(key, nil, rcmgr) require.NoError(t, err) ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) @@ -277,7 +277,7 @@ func TestResourceManagerListening(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tr, err := libp2pwebtransport.New(key, rcmgr) + tr, err := libp2pwebtransport.New(key, nil, rcmgr) require.NoError(t, err) ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) @@ -305,3 +305,38 @@ func TestResourceManagerListening(t *testing.T) { } }) } + +// TODO: unify somehow. We do the same in libp2pquic. +//go:generate sh -c "mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" + +func TestConnectionGaterDialing(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + expected := ln.Multiaddr() + for { + _, err := expected.ValueForProtocol(ma.P_CERTHASH) + if err != nil { + break + } + expected, _ = ma.SplitLast(expected) + } + require.Equal(t, expected, addrs.RemoteMultiaddr()) + }) + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.EqualError(t, err, "secured connection gated") +} From a508e940049967c94b0a74d008a31f83eb32a55b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 22:21:26 +0000 Subject: [PATCH 31/44] implement InterceptAccept connection gating --- p2p/transport/webtransport/listener.go | 14 ++++-- p2p/transport/webtransport/transport.go | 17 +++++--- p2p/transport/webtransport/transport_test.go | 46 ++++++++++++++++---- 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 73794f2f3c..9409612dcf 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -9,6 +9,7 @@ import ( "net/http" "time" + "github.com/libp2p/go-libp2p-core/connmgr" "github.com/libp2p/go-libp2p-core/network" tpt "github.com/libp2p/go-libp2p-core/transport" @@ -30,6 +31,7 @@ type listener struct { noise *noise.Transport certManager *certManager rcmgr network.ResourceManager + gater connmgr.ConnectionGater server webtransport.Server @@ -46,7 +48,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, rcmgr network.ResourceManager) (tpt.Listener, error) { +func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) { network, addr, err := manet.DialArgs(laddr) if err != nil { return nil, err @@ -68,6 +70,7 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans noise: noise, certManager: certManager, rcmgr: rcmgr, + gater: gater, queue: make(chan tpt.CapableConn, queueLen), serverClosed: make(chan struct{}), addr: udpConn.LocalAddr(), @@ -106,6 +109,10 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) return } + if l.gater != nil && !l.gater.InterceptAccept(&connMultiaddrs{local: l.multiaddr, remote: remoteMultiaddr}) { + w.WriteHeader(http.StatusForbidden) + return + } connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr) if err != nil { @@ -174,9 +181,8 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (c } return &connSecurityMultiaddrsImpl{ - ConnSecurity: c, - local: local, - remote: remote, + ConnSecurity: c, + ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, }, nil } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 834412df36..c20c4c47ed 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -39,13 +39,17 @@ type connSecurityMultiaddrs interface { type connSecurityMultiaddrsImpl struct { network.ConnSecurity + network.ConnMultiaddrs +} + +type connMultiaddrs struct { local, remote ma.Multiaddr } -var _ connSecurityMultiaddrs = &connSecurityMultiaddrsImpl{} +var _ network.ConnMultiaddrs = &connMultiaddrs{} -func (c *connSecurityMultiaddrsImpl) LocalMultiaddr() ma.Multiaddr { return c.local } -func (c *connSecurityMultiaddrsImpl) RemoteMultiaddr() ma.Multiaddr { return c.remote } +func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } +func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } type transport struct { privKey ic.PrivKey @@ -176,9 +180,8 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p return nil, err } return &connSecurityMultiaddrsImpl{ - ConnSecurity: c, - local: local, - remote: remote, + ConnSecurity: c, + ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, }, nil } @@ -225,7 +228,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if t.listenOnceErr != nil { return nil, t.listenOnceErr } - return newListener(laddr, t, t.noise, t.certManager, t.rcmgr) + return newListener(laddr, t, t.noise, t.certManager, t.gater, t.rcmgr) } func (t *transport) Protocols() []int { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index ff73e0f9d8..c317adbd63 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -54,6 +54,16 @@ func extractCertHashes(addr ma.Multiaddr) []string { return certHashesStr } +func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr { + for { + _, err := addr.ValueForProtocol(ma.P_CERTHASH) + if err != nil { + return addr + } + addr, _ = ma.SplitLast(addr) + } +} + func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) @@ -323,15 +333,7 @@ func TestConnectionGaterDialing(t *testing.T) { defer ln.Close() connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - expected := ln.Multiaddr() - for { - _, err := expected.ValueForProtocol(ma.P_CERTHASH) - if err != nil { - break - } - expected, _ = ma.SplitLast(expected) - } - require.Equal(t, expected, addrs.RemoteMultiaddr()) + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) }) _, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager) @@ -340,3 +342,29 @@ func TestConnectionGaterDialing(t *testing.T) { _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "secured connection gated") } + +func TestConnectionGaterInterceptAccept(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) + require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) + }) + + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.EqualError(t, err, "received status 403") +} From 7e8ca3ac06fbe0c514d3a7f95758b69c8fd33074 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 22:25:58 +0000 Subject: [PATCH 32/44] implement InterceptSecured for accepted connections --- p2p/transport/webtransport/listener.go | 7 ++++ p2p/transport/webtransport/transport_test.go | 40 ++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 9409612dcf..bb762f7d7a 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -141,6 +141,13 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { } cancel() + if l.gater != nil && !l.gater.InterceptSecured(network.DirInbound, sconn.RemotePeer(), sconn) { + // TODO: can we close with a specific error here? + sess.Close() + connScope.Done() + return + } + if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) sess.Close() diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index c317adbd63..8540cabd9b 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -368,3 +368,43 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "received status 403") } + +func TestConnectionGaterInterceptSecured(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + clientID, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) + connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) + require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) + }) + // The handshake will complete, but the server will immediately close the connection. + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + done := make(chan struct{}) + go func() { + defer close(done) + _, err = conn.AcceptStream() + require.Error(t, err) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} From 851e6ba1716c978e91c48966bf35f476755d0d7f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Jul 2022 12:36:39 +0000 Subject: [PATCH 33/44] verify the hash of the server's certificate (#16) --- .../webtransport/cert_manager_test.go | 14 +- p2p/transport/webtransport/crypto.go | 50 ++++++- p2p/transport/webtransport/crypto_test.go | 128 ++++++++++++++++++ p2p/transport/webtransport/transport.go | 24 ++-- p2p/transport/webtransport/transport_test.go | 1 + 5 files changed, 199 insertions(+), 18 deletions(-) create mode 100644 p2p/transport/webtransport/crypto_test.go diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index 3d9c881d4e..cb55a56ddc 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -1,16 +1,22 @@ package libp2pwebtransport import ( - "github.com/multiformats/go-multibase" - "github.com/multiformats/go-multihash" + "crypto/sha256" + "crypto/tls" "os" "testing" "time" ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" "github.com/stretchr/testify/require" ) +func certificateHashFromTLSConfig(c *tls.Config) [32]byte { + return sha256.Sum256(c.Certificates[0].Certificate[0]) +} + func splitMultiaddr(addr ma.Multiaddr) []ma.Component { var components []ma.Component ma.ForEach(addr, func(c ma.Component) bool { @@ -44,7 +50,7 @@ func TestInitialCert(t *testing.T) { components := splitMultiaddr(addr) require.Len(t, components, 1) require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) - hash := certificateHash(conf) + hash := certificateHashFromTLSConfig(conf) require.Equal(t, hash[:], certHashFromComponent(t, components[0])) } @@ -71,6 +77,6 @@ func TestCertRenewal(t *testing.T) { require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, certValidity/2, 10*time.Millisecond) newConf := m.GetConfig() // check that the new config now matches the second component - hash := certificateHash(newConf) + hash := certificateHashFromTLSConfig(newConf) require.Equal(t, hash[:], certHashFromComponent(t, components[1])) } diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go index f30943c6f6..9bb7f7a330 100644 --- a/p2p/transport/webtransport/crypto.go +++ b/p2p/transport/webtransport/crypto.go @@ -1,6 +1,7 @@ package libp2pwebtransport import ( + "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -9,13 +10,13 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/binary" + "errors" + "fmt" "math/big" "time" -) -func certificateHash(c *tls.Config) [32]byte { - return sha256.Sum256(c.Certificates[0].Certificate[0]) -} + "github.com/multiformats/go-multihash" +) func getTLSConf(start, end time.Time) (*tls.Config, error) { cert, priv, err := generateCert(start, end) @@ -61,3 +62,44 @@ func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, e } return ca, caPrivateKey, nil } + +func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash) error { + if len(rawCerts) < 1 { + return errors.New("no cert") + } + leaf := rawCerts[len(rawCerts)-1] + // The W3C WebTransport specification currently only allows SHA-256 certificates for serverCertificateHashes. + hash := sha256.Sum256(leaf) + var verified bool + for _, h := range certHashes { + if h.Code == multihash.SHA2_256 && bytes.Equal(h.Digest, hash[:]) { + verified = true + break + } + } + if !verified { + digests := make([][]byte, 0, len(certHashes)) + for _, h := range certHashes { + digests = append(digests, h.Digest) + } + return fmt.Errorf("cert hash not found: %#x (expected: %#x)", hash, digests) + } + + cert, err := x509.ParseCertificate(leaf) + if err != nil { + return err + } + // TODO: is this the best (and complete?) way to identify RSA certificates? + switch cert.SignatureAlgorithm { + case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA, x509.MD2WithRSA, x509.MD5WithRSA: + return errors.New("cert uses RSA") + } + if l := cert.NotAfter.Sub(cert.NotBefore); l > 14*24*time.Hour { + return fmt.Errorf("cert must not be valid for longer than 14 days (NotBefore: %s, NotAfter: %s, Length: %s)", cert.NotBefore, cert.NotAfter, l) + } + now := time.Now() + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return fmt.Errorf("cert not valid (NotBefore: %s, NotAfter: %s)", cert.NotBefore, cert.NotAfter) + } + return nil +} diff --git a/p2p/transport/webtransport/crypto_test.go b/p2p/transport/webtransport/crypto_test.go new file mode 100644 index 0000000000..c69c578b92 --- /dev/null +++ b/p2p/transport/webtransport/crypto_test.go @@ -0,0 +1,128 @@ +package libp2pwebtransport + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + mrand "math/rand" + "testing" + "time" + + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" +) + +func sha256Multihash(t *testing.T, b []byte) multihash.DecodedMultihash { + t.Helper() + hash := sha256.Sum256(b) + h, err := multihash.Encode(hash[:], multihash.SHA2_256) + require.NoError(t, err) + dh, err := multihash.Decode(h) + require.NoError(t, err) + return *dh +} + +func generateCertWithKey(t *testing.T, key crypto.PrivateKey, start, end time.Time) *x509.Certificate { + t.Helper() + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(int64(mrand.Uint64())), + Subject: pkix.Name{}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, key.(interface{ Public() crypto.PublicKey }).Public(), key) + require.NoError(t, err) + ca, err := x509.ParseCertificate(caBytes) + require.NoError(t, err) + return ca +} + +func TestCertificateVerification(t *testing.T) { + now := time.Now() + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + t.Run("accepting a valid cert", func(t *testing.T) { + validCert := generateCertWithKey(t, ecdsaKey, now, now.Add(14*24*time.Hour)) + require.NoError(t, verifyRawCerts([][]byte{validCert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, validCert.Raw)})) + }) + + for _, tc := range [...]struct { + name string + cert *x509.Certificate + errStr string + }{ + { + name: "validitity period too long", + cert: generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)), + errStr: "cert must not be valid for longer than 14 days", + }, + { + name: "uses RSA key", + cert: generateCertWithKey(t, rsaKey, now, now.Add(14*24*time.Hour)), + errStr: "RSA", + }, + { + name: "expired certificate", + cert: generateCertWithKey(t, ecdsaKey, now.Add(-14*24*time.Hour), now), + errStr: "cert not valid", + }, + { + name: "not yet valid", + cert: generateCertWithKey(t, ecdsaKey, now.Add(time.Hour), now.Add(time.Hour+14*24*time.Hour)), + errStr: "cert not valid", + }, + } { + tc := tc + t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) { + err := verifyRawCerts([][]byte{tc.cert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, tc.cert.Raw)}) + require.Error(t, err) + require.Contains(t, err.Error(), tc.errStr) + }) + } + + for _, tc := range [...]struct { + name string + certs [][]byte + hashes []multihash.DecodedMultihash + errStr string + }{ + { + name: "no certificates", + hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, + errStr: "no cert", + }, + { + name: "certificate not parseable", + certs: [][]byte{[]byte("foobar")}, + hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, + errStr: "x509: malformed certificate", + }, + { + name: "hash mismatch", + certs: [][]byte{generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)).Raw}, + hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, + errStr: "cert hash not found", + }, + } { + tc := tc + t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) { + err := verifyRawCerts(tc.certs, tc.hashes) + require.Error(t, err) + require.Contains(t, err.Error(), tc.errStr) + }) + } +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index c20c4c47ed..68938565cb 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -3,6 +3,7 @@ package libp2pwebtransport import ( "context" "crypto/tls" + "crypto/x509" "fmt" "io" "sync" @@ -55,8 +56,6 @@ type transport struct { privKey ic.PrivKey pid peer.ID - dialer webtransport.Dialer - rcmgr network.ResourceManager gater connmgr.ConnectionGater @@ -80,11 +79,6 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa privKey: key, rcmgr: rcmgr, gater: gater, - dialer: webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate, - }, - }, } noise, err := noise.New(key, noise.WithEarlyDataHandler(t.checkEarlyData)) if err != nil { @@ -115,7 +109,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } - sess, err := t.dial(ctx, addr) + sess, err := t.dial(ctx, addr, certHashes) if err != nil { scope.Done() return nil, err @@ -135,9 +129,19 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return newConn(t, sess, sconn, scope), nil } -func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Session, error) { +func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) - rsp, sess, err := t.dialer.Dial(ctx, url, nil) + dialer := webtransport.Dialer{ + RoundTripper: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // this is not insecure. We verify the certificate ourselves. + VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyRawCerts(rawCerts, certHashes) + }, + }, + }, + } + rsp, sess, err := dialer.Dial(ctx, url, nil) if err != nil { return nil, err } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 8540cabd9b..3bc6d27efc 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -145,6 +145,7 @@ func TestHashVerification(t *testing.T) { _, err := tr2.Dial(context.Background(), addr, serverID) require.Error(t, err) + require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found") }) t.Run("fails when adding a wrong hash", func(t *testing.T) { From f0dbd3e7e740b4466b38073e59e13f9b523b6e33 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Jul 2022 13:01:21 +0000 Subject: [PATCH 34/44] fix flaky resource manager test (#19) --- p2p/transport/webtransport/transport_test.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 3bc6d27efc..d85fcb828f 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -294,23 +294,29 @@ func TestResourceManagerListening(t *testing.T) { require.NoError(t, err) defer ln.Close() + serverDone := make(chan struct{}) scope := mocknetwork.NewMockConnManagementScope(ctrl) rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil) scope.EXPECT().SetPeer(clientID).Return(errors.New("denied")) - scope.EXPECT().Done() + scope.EXPECT().Done().Do(func() { close(serverDone) }) // The handshake will complete, but the server will immediately close the connection. conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() - done := make(chan struct{}) + clientDone := make(chan struct{}) go func() { - defer close(done) + defer close(clientDone) _, err = conn.AcceptStream() require.Error(t, err) }() select { - case <-done: + case <-clientDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + select { + case <-serverDone: case <-time.After(5 * time.Second): t.Fatal("timeout") } From d626e806968e1014a268d13d00a89290a2b17e58 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Jul 2022 20:49:27 +0000 Subject: [PATCH 35/44] move connection interface to conn.go --- p2p/transport/webtransport/conn.go | 21 +++++++++++++++++++++ p2p/transport/webtransport/transport.go | 19 ------------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index a8f73b079d..409cea3383 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -2,12 +2,33 @@ package libp2pwebtransport import ( "context" + "github.com/libp2p/go-libp2p-core/network" tpt "github.com/libp2p/go-libp2p-core/transport" "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" ) +type connSecurityMultiaddrs interface { + network.ConnMultiaddrs + network.ConnSecurity +} + +type connSecurityMultiaddrsImpl struct { + network.ConnSecurity + network.ConnMultiaddrs +} + +type connMultiaddrs struct { + local, remote ma.Multiaddr +} + +var _ network.ConnMultiaddrs = &connMultiaddrs{} + +func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } +func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } + type conn struct { connSecurityMultiaddrs diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 68938565cb..5438e259b7 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -33,25 +33,6 @@ const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" const certValidity = 14 * 24 * time.Hour -type connSecurityMultiaddrs interface { - network.ConnMultiaddrs - network.ConnSecurity -} - -type connSecurityMultiaddrsImpl struct { - network.ConnSecurity - network.ConnMultiaddrs -} - -type connMultiaddrs struct { - local, remote ma.Multiaddr -} - -var _ network.ConnMultiaddrs = &connMultiaddrs{} - -func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } -func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } - type transport struct { privKey ic.PrivKey pid peer.ID From ebcb51309fc6825a6f8cde3fb2cb3df316e7822e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Jul 2022 21:19:25 +0000 Subject: [PATCH 36/44] use a mock clock in cert manager tests (#20) --- p2p/transport/webtransport/cert_manager.go | 17 +++++++------ .../webtransport/cert_manager_test.go | 25 +++++++++++-------- p2p/transport/webtransport/transport.go | 21 ++++++++++++++-- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index f2a0afd04e..8ca5b85913 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/benbjohnson/clock" ma "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" @@ -40,6 +41,7 @@ func newCertConfig(start, end time.Time, conf *tls.Config) (*certConfig, error) // We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days, // as the client might be connecting with a cached address. type certManager struct { + clock clock.Clock ctx context.Context ctxCancel context.CancelFunc refCount sync.WaitGroup @@ -53,18 +55,22 @@ type certManager struct { addrComp ma.Multiaddr } -func newCertManager(certValidity time.Duration) (*certManager, error) { +func newCertManager(clock clock.Clock, certValidity time.Duration) (*certManager, error) { m := &certManager{ + clock: clock, certValidity: certValidity, } m.ctx, m.ctxCancel = context.WithCancel(context.Background()) if err := m.init(); err != nil { return nil, err } + + t := m.clock.Ticker(m.certValidity * 4 / 9) // make sure we're a bit faster than 1/2 m.refCount.Add(1) go func() { defer m.refCount.Done() - if err := m.background(); err != nil { + defer t.Stop() + if err := m.background(t); err != nil { log.Fatal(err) } }() @@ -72,7 +78,7 @@ func newCertManager(certValidity time.Duration) (*certManager, error) { } func (m *certManager) init() error { - start := time.Now() + start := m.clock.Now() end := start.Add(m.certValidity) tlsConf, err := getTLSConf(start, end) if err != nil { @@ -86,10 +92,7 @@ func (m *certManager) init() error { return m.cacheAddrComponent() } -func (m *certManager) background() error { - t := time.NewTicker(m.certValidity * 4 / 9) // make sure we're a bit faster than 1/2 - defer t.Stop() - +func (m *certManager) background(t *clock.Ticker) error { for { select { case <-m.ctx.Done(): diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index cb55a56ddc..84cb93510a 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -3,10 +3,11 @@ package libp2pwebtransport import ( "crypto/sha256" "crypto/tls" - "os" + "fmt" "testing" "time" + "github.com/benbjohnson/clock" ma "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" @@ -37,15 +38,17 @@ func certHashFromComponent(t *testing.T, comp ma.Component) []byte { } func TestInitialCert(t *testing.T) { - m, err := newCertManager(certValidity) + cl := clock.NewMock() + cl.Add(1234567 * time.Hour) + m, err := newCertManager(cl, certValidity) require.NoError(t, err) defer m.Close() conf := m.GetConfig() require.Len(t, conf.Certificates, 1) cert := conf.Certificates[0] - require.WithinDuration(t, time.Now(), cert.Leaf.NotBefore, time.Second) - require.WithinDuration(t, time.Now().Add(certValidity), cert.Leaf.NotAfter, time.Second) + require.Equal(t, cl.Now().UTC(), cert.Leaf.NotBefore) + require.Equal(t, cl.Now().Add(certValidity).UTC(), cert.Leaf.NotAfter) addr := m.AddrComponent() components := splitMultiaddr(addr) require.Len(t, components, 1) @@ -55,18 +58,17 @@ func TestInitialCert(t *testing.T) { } func TestCertRenewal(t *testing.T) { - var certValidity = 300 * time.Millisecond - if os.Getenv("CI") != "" { - certValidity = 2 * time.Second - } - m, err := newCertManager(certValidity) + cl := clock.NewMock() + m, err := newCertManager(cl, certValidity) require.NoError(t, err) defer m.Close() firstConf := m.GetConfig() require.Len(t, splitMultiaddr(m.AddrComponent()), 1) // wait for a new certificate to be generated - require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, certValidity/2, 10*time.Millisecond) + fmt.Println("add time") + cl.Add(certValidity / 2) + require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond) // the actual config used should still be the same, we're just advertising the hash of the next config components := splitMultiaddr(m.AddrComponent()) require.Len(t, components, 2) @@ -74,7 +76,8 @@ func TestCertRenewal(t *testing.T) { require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) } require.Equal(t, firstConf, m.GetConfig()) - require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, certValidity/2, 10*time.Millisecond) + cl.Add(certValidity / 2) + require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) newConf := m.GetConfig() // check that the new config now matches the second component hash := certificateHashFromTLSConfig(newConf) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 5438e259b7..f27fcfdb99 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -19,6 +19,7 @@ import ( noise "github.com/libp2p/go-libp2p-noise" + "github.com/benbjohnson/clock" logging "github.com/ipfs/go-log/v2" "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" @@ -33,9 +34,19 @@ const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" const certValidity = 14 * 24 * time.Hour +type Option func(*transport) error + +func WithClock(cl clock.Clock) Option { + return func(t *transport) error { + t.clock = cl + return nil + } +} + type transport struct { privKey ic.PrivKey pid peer.ID + clock clock.Clock rcmgr network.ResourceManager gater connmgr.ConnectionGater @@ -50,7 +61,7 @@ type transport struct { var _ tpt.Transport = &transport{} var _ io.Closer = &transport{} -func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) { +func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err @@ -60,6 +71,12 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa privKey: key, rcmgr: rcmgr, gater: gater, + clock: clock.New(), + } + for _, opt := range opts { + if err := opt(t); err != nil { + return nil, err + } } noise, err := noise.New(key, noise.WithEarlyDataHandler(t.checkEarlyData)) if err != nil { @@ -208,7 +225,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) } t.listenOnce.Do(func() { - t.certManager, t.listenOnceErr = newCertManager(certValidity) + t.certManager, t.listenOnceErr = newCertManager(t.clock, certValidity) }) if t.listenOnceErr != nil { return nil, t.listenOnceErr From 9f2e830b65500721934073fb31c21e44c1f1fcce Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 10 Jul 2022 21:22:45 +0000 Subject: [PATCH 37/44] remove member variable for certificate validity from cert manager --- p2p/transport/webtransport/cert_manager.go | 15 +++++---------- p2p/transport/webtransport/cert_manager_test.go | 6 ++---- p2p/transport/webtransport/transport.go | 2 +- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index 8ca5b85913..a797b55085 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -46,8 +46,6 @@ type certManager struct { ctxCancel context.CancelFunc refCount sync.WaitGroup - certValidity time.Duration // so we can set it in tests - mx sync.Mutex lastConfig *certConfig // initially nil currentConfig *certConfig @@ -55,17 +53,14 @@ type certManager struct { addrComp ma.Multiaddr } -func newCertManager(clock clock.Clock, certValidity time.Duration) (*certManager, error) { - m := &certManager{ - clock: clock, - certValidity: certValidity, - } +func newCertManager(clock clock.Clock) (*certManager, error) { + m := &certManager{clock: clock} m.ctx, m.ctxCancel = context.WithCancel(context.Background()) if err := m.init(); err != nil { return nil, err } - t := m.clock.Ticker(m.certValidity * 4 / 9) // make sure we're a bit faster than 1/2 + t := m.clock.Ticker(certValidity * 4 / 9) // make sure we're a bit faster than 1/2 m.refCount.Add(1) go func() { defer m.refCount.Done() @@ -79,7 +74,7 @@ func newCertManager(clock clock.Clock, certValidity time.Duration) (*certManager func (m *certManager) init() error { start := m.clock.Now() - end := start.Add(m.certValidity) + end := start.Add(certValidity) tlsConf, err := getTLSConf(start, end) if err != nil { return err @@ -98,7 +93,7 @@ func (m *certManager) background(t *clock.Ticker) error { case <-m.ctx.Done(): return nil case start := <-t.C: - end := start.Add(m.certValidity) + end := start.Add(certValidity) tlsConf, err := getTLSConf(start, end) if err != nil { return err diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index 84cb93510a..69cd7163da 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -3,7 +3,6 @@ package libp2pwebtransport import ( "crypto/sha256" "crypto/tls" - "fmt" "testing" "time" @@ -40,7 +39,7 @@ func certHashFromComponent(t *testing.T, comp ma.Component) []byte { func TestInitialCert(t *testing.T) { cl := clock.NewMock() cl.Add(1234567 * time.Hour) - m, err := newCertManager(cl, certValidity) + m, err := newCertManager(cl) require.NoError(t, err) defer m.Close() @@ -59,14 +58,13 @@ func TestInitialCert(t *testing.T) { func TestCertRenewal(t *testing.T) { cl := clock.NewMock() - m, err := newCertManager(cl, certValidity) + m, err := newCertManager(cl) require.NoError(t, err) defer m.Close() firstConf := m.GetConfig() require.Len(t, splitMultiaddr(m.AddrComponent()), 1) // wait for a new certificate to be generated - fmt.Println("add time") cl.Add(certValidity / 2) require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond) // the actual config used should still be the same, we're just advertising the hash of the next config diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index f27fcfdb99..c32d1059a3 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -225,7 +225,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) } t.listenOnce.Do(func() { - t.certManager, t.listenOnceErr = newCertManager(t.clock, certValidity) + t.certManager, t.listenOnceErr = newCertManager(t.clock) }) if t.listenOnceErr != nil { return nil, t.listenOnceErr From ff5aa304a0aca38967f4a2378bd233fb360f8a64 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 12 Jul 2022 10:43:19 +0000 Subject: [PATCH 38/44] simplify certificate generation --- p2p/transport/webtransport/cert_manager.go | 25 ++++++++-------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index a797b55085..519a89d88e 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -16,15 +16,16 @@ import ( ) type certConfig struct { - start, end time.Time - tlsConf *tls.Config - sha256 [32]byte // cached from the tlsConf + tlsConf *tls.Config + sha256 [32]byte // cached from the tlsConf } -func newCertConfig(start, end time.Time, conf *tls.Config) (*certConfig, error) { +func newCertConfig(start, end time.Time) (*certConfig, error) { + conf, err := getTLSConf(start, end) + if err != nil { + return nil, err + } return &certConfig{ - start: start, - end: end, tlsConf: conf, sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw), }, nil @@ -75,11 +76,7 @@ func newCertManager(clock clock.Clock) (*certManager, error) { func (m *certManager) init() error { start := m.clock.Now() end := start.Add(certValidity) - tlsConf, err := getTLSConf(start, end) - if err != nil { - return err - } - cc, err := newCertConfig(start, end, tlsConf) + cc, err := newCertConfig(start, end) if err != nil { return err } @@ -94,11 +91,7 @@ func (m *certManager) background(t *clock.Ticker) error { return nil case start := <-t.C: end := start.Add(certValidity) - tlsConf, err := getTLSConf(start, end) - if err != nil { - return err - } - cc, err := newCertConfig(start, end, tlsConf) + cc, err := newCertConfig(start, end) if err != nil { return err } From 2823159a99d45518a94db48f9942fa54005f2797 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Jul 2022 10:22:10 +0000 Subject: [PATCH 39/44] optimize expiry periods of certificates (#21) --- p2p/transport/webtransport/cert_manager.go | 106 ++++++++++-------- .../webtransport/cert_manager_test.go | 53 ++++++--- p2p/transport/webtransport/transport_test.go | 13 ++- 3 files changed, 105 insertions(+), 67 deletions(-) diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index 519a89d88e..c46a60b666 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -15,11 +15,19 @@ import ( "github.com/multiformats/go-multihash" ) +// Allow for a bit of clock skew. +// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time. +// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time. +const clockSkewAllowance = time.Hour + type certConfig struct { tlsConf *tls.Config sha256 [32]byte // cached from the tlsConf } +func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore } +func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter } + func newCertConfig(start, end time.Time) (*certConfig, error) { conf, err := getTLSConf(start, end) if err != nil { @@ -32,22 +40,17 @@ func newCertConfig(start, end time.Time) (*certConfig, error) { } // Certificate renewal logic: -// 0. To simplify the math, assume the certificate is valid for 10 days (in real life: 14 days). -// 1. On startup, we generate the first certificate (1). -// 2. After 4 days, we generate a second certificate (2). -// We don't use that certificate yet, but we advertise the hashes of (1) and (2). -// That allows clients to connect to us using addresses that are 4 days old. -// 3. After another 4 days, we now actually start using (2). -// We also generate a third certificate (3), and start advertising the hashes of (2) and (3). -// We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days, -// as the client might be connecting with a cached address. +// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another +// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew). +// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate. +// At the same time, we stop advertising the certhash of the first cert and generate the next cert. type certManager struct { clock clock.Clock ctx context.Context ctxCancel context.CancelFunc refCount sync.WaitGroup - mx sync.Mutex + mx sync.RWMutex lastConfig *certConfig // initially nil currentConfig *certConfig nextConfig *certConfig // nil until we have passed half the certValidity of the current config @@ -61,64 +64,71 @@ func newCertManager(clock clock.Clock) (*certManager, error) { return nil, err } - t := m.clock.Ticker(certValidity * 4 / 9) // make sure we're a bit faster than 1/2 - m.refCount.Add(1) - go func() { - defer m.refCount.Done() - defer t.Stop() - if err := m.background(t); err != nil { - log.Fatal(err) - } - }() + m.background() return m, nil } func (m *certManager) init() error { - start := m.clock.Now() - end := start.Add(certValidity) - cc, err := newCertConfig(start, end) + start := m.clock.Now().Add(-clockSkewAllowance) + var err error + m.nextConfig, err = newCertConfig(start, start.Add(certValidity)) if err != nil { return err } - m.currentConfig = cc + return m.rollConfig() +} + +func (m *certManager) rollConfig() error { + // We stop using the current certificate clockSkewAllowance before its expiry time. + // At this point, the next certificate needs to be valid for one clockSkewAllowance. + nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance) + c, err := newCertConfig(nextStart, nextStart.Add(certValidity)) + if err != nil { + return err + } + m.lastConfig = m.currentConfig + m.currentConfig = m.nextConfig + m.nextConfig = c return m.cacheAddrComponent() } -func (m *certManager) background(t *clock.Ticker) error { - for { - select { - case <-m.ctx.Done(): - return nil - case start := <-t.C: - end := start.Add(certValidity) - cc, err := newCertConfig(start, end) - if err != nil { - return err - } - m.mx.Lock() - if m.nextConfig != nil { - m.lastConfig = m.currentConfig - m.currentConfig = m.nextConfig - } - m.nextConfig = cc - if err := m.cacheAddrComponent(); err != nil { +func (m *certManager) background() { + d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now()) + log.Debugw("setting timer", "duration", d.String()) + t := m.clock.Timer(d) + m.refCount.Add(1) + + go func() { + defer m.refCount.Done() + defer t.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case now := <-t.C: + m.mx.Lock() + if err := m.rollConfig(); err != nil { + log.Errorw("rolling config failed", "error", err) + } + d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now) + log.Debugw("rolling certificates", "next", d.String()) + t.Reset(d) m.mx.Unlock() - return err } - m.mx.Unlock() } - } + }() } func (m *certManager) GetConfig() *tls.Config { - m.mx.Lock() - defer m.mx.Unlock() + m.mx.RLock() + defer m.mx.RUnlock() return m.currentConfig.tlsConf } func (m *certManager) AddrComponent() ma.Multiaddr { - m.mx.Lock() - defer m.mx.Unlock() + m.mx.RLock() + defer m.mx.RUnlock() return m.addrComp } diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index 69cd7163da..3f2328fbb7 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -46,14 +46,15 @@ func TestInitialCert(t *testing.T) { conf := m.GetConfig() require.Len(t, conf.Certificates, 1) cert := conf.Certificates[0] - require.Equal(t, cl.Now().UTC(), cert.Leaf.NotBefore) - require.Equal(t, cl.Now().Add(certValidity).UTC(), cert.Leaf.NotAfter) + require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore) + require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter) addr := m.AddrComponent() components := splitMultiaddr(addr) - require.Len(t, components, 1) + require.Len(t, components, 2) require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) hash := certificateHashFromTLSConfig(conf) require.Equal(t, hash[:], certHashFromComponent(t, components[0])) + require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code) } func TestCertRenewal(t *testing.T) { @@ -63,21 +64,39 @@ func TestCertRenewal(t *testing.T) { defer m.Close() firstConf := m.GetConfig() - require.Len(t, splitMultiaddr(m.AddrComponent()), 1) + first := splitMultiaddr(m.AddrComponent()) + require.Len(t, first, 2) + require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ") // wait for a new certificate to be generated - cl.Add(certValidity / 2) - require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond) - // the actual config used should still be the same, we're just advertising the hash of the next config - components := splitMultiaddr(m.AddrComponent()) - require.Len(t, components, 2) - for _, c := range components { + cl.Add(certValidity - 2*clockSkewAllowance - time.Second) + require.Never(t, func() bool { + for i, c := range splitMultiaddr(m.AddrComponent()) { + if c.Value() != first[i].Value() { + return true + } + } + return false + }, 100*time.Millisecond, 10*time.Millisecond) + cl.Add(2 * time.Second) + require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) + secondConf := m.GetConfig() + + second := splitMultiaddr(m.AddrComponent()) + require.Len(t, second, 2) + for _, c := range second { require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) } - require.Equal(t, firstConf, m.GetConfig()) - cl.Add(certValidity / 2) - require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) - newConf := m.GetConfig() - // check that the new config now matches the second component - hash := certificateHashFromTLSConfig(newConf) - require.Equal(t, hash[:], certHashFromComponent(t, components[1])) + // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate + require.Equal(t, first[1].Value(), second[0].Value()) + require.NotEqual(t, first[0].Value(), second[1].Value()) + + cl.Add(certValidity - 2*clockSkewAllowance + time.Second) + require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond) + third := splitMultiaddr(m.AddrComponent()) + require.Len(t, third, 2) + for _, c := range third { + require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) + } + // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate + require.Equal(t, second[1].Value(), third[0].Value()) } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index d85fcb828f..61cd245138 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -140,7 +140,16 @@ func TestHashVerification(t *testing.T) { t.Run("fails using only a wrong hash", func(t *testing.T) { // replace the certificate hash in the multiaddr with a fake hash - addr, _ := ma.SplitLast(ln.Multiaddr()) + addr := ln.Multiaddr() + // strip off all certhash components + for { + a, comp := ma.SplitLast(addr) + if comp.Protocol().Code != ma.P_CERTHASH { + break + } + addr = a + } + addr = addr.Encapsulate(foobarHash) _, err := tr2.Dial(context.Background(), addr, serverID) @@ -224,7 +233,7 @@ func TestListenerAddrs(t *testing.T) { ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) hashes1 := extractCertHashes(ln1.Multiaddr()) - require.Len(t, hashes1, 1) + require.Len(t, hashes1, 2) hashes2 := extractCertHashes(ln2.Multiaddr()) require.Equal(t, hashes1, hashes2) } From d74921df0a4be3b73241b26236c4ff6b36b0f5d0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Jul 2022 13:55:57 +0000 Subject: [PATCH 40/44] make it possible to use a custom tls.Config for listening and dialing (#22) --- p2p/transport/webtransport/listener.go | 49 ++++---- p2p/transport/webtransport/transport.go | 68 +++++++--- p2p/transport/webtransport/transport_test.go | 123 ++++++++++++++++--- 3 files changed, 184 insertions(+), 56 deletions(-) diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index bb762f7d7a..872e65804b 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -27,11 +27,13 @@ const queueLen = 16 const handshakeTimeout = 10 * time.Second type listener struct { - transport tpt.Transport - noise *noise.Transport - certManager *certManager - rcmgr network.ResourceManager - gater connmgr.ConnectionGater + transport tpt.Transport + noise *noise.Transport + certManager *certManager + staticTLSConf *tls.Config + + rcmgr network.ResourceManager + gater connmgr.ConnectionGater server webtransport.Server @@ -48,7 +50,7 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) { +func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) { network, addr, err := manet.DialArgs(laddr) if err != nil { return nil, err @@ -65,23 +67,23 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans if err != nil { return nil, err } + if tlsConf == nil { + tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return certManager.GetConfig(), nil + }} + } ln := &listener{ - transport: transport, - noise: noise, - certManager: certManager, - rcmgr: rcmgr, - gater: gater, - queue: make(chan tpt.CapableConn, queueLen), - serverClosed: make(chan struct{}), - addr: udpConn.LocalAddr(), - multiaddr: localMultiaddr, - server: webtransport.Server{ - H3: http3.Server{ - TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return certManager.GetConfig(), nil - }}, - }, - }, + transport: transport, + noise: noise, + certManager: certManager, + staticTLSConf: tlsConf, + rcmgr: rcmgr, + gater: gater, + queue: make(chan tpt.CapableConn, queueLen), + serverClosed: make(chan struct{}), + addr: udpConn.LocalAddr(), + multiaddr: localMultiaddr, + server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, } ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) mux := http.NewServeMux() @@ -198,6 +200,9 @@ func (l *listener) Addr() net.Addr { } func (l *listener) Multiaddr() ma.Multiaddr { + if l.certManager == nil { + return l.multiaddr + } return l.multiaddr.Encapsulate(l.certManager.AddrComponent()) } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index c32d1059a3..e027e1fe8f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "sync" @@ -43,6 +44,27 @@ func WithClock(cl clock.Clock) Option { } } +// WithTLSConfig sets a tls.Config used for listening. +// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr. +// This is most useful when running a listener that has a valid (CA-signed) certificate. +func WithTLSConfig(c *tls.Config) Option { + return func(t *transport) error { + t.staticTLSConf = c + return nil + } +} + +// WithTLSClientConfig sets a custom tls.Config used for dialing. +// This option is most useful for setting a custom tls.Config.RootCAs certificate pool. +// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and +// overwrite the VerifyPeerCertificate callback. +func WithTLSClientConfig(c *tls.Config) Option { + return func(t *transport) error { + t.tlsClientConf = c + return nil + } +} + type transport struct { privKey ic.PrivKey pid peer.ID @@ -54,6 +76,8 @@ type transport struct { listenOnce sync.Once listenOnceErr error certManager *certManager + staticTLSConf *tls.Config + tlsClientConf *tls.Config noise *noise.Transport } @@ -129,15 +153,21 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) + var tlsConf *tls.Config + if t.tlsClientConf != nil { + tlsConf = t.tlsClientConf.Clone() + } else { + tlsConf = &tls.Config{} + } + + if len(certHashes) > 0 { + tlsConf.InsecureSkipVerify = true // this is not insecure. We verify the certificate ourselves. + tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyRawCerts(rawCerts, certHashes) + } + } dialer := webtransport.Dialer{ - RoundTripper: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, // this is not insecure. We verify the certificate ourselves. - VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - return verifyRawCerts(rawCerts, certHashes) - }, - }, - }, + RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf}, } rsp, sess, err := dialer.Dial(ctx, url, nil) if err != nil { @@ -193,6 +223,14 @@ func (t *transport) checkEarlyData(b []byte) error { return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) } hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) + + if t.staticTLSConf != nil { + if len(hashes) > 0 { + return errors.New("using static TLS config, didn't expect any certificate hashes") + } + return nil + } + for _, h := range msg.CertHashes { dh, err := multihash.Decode(h) if err != nil { @@ -224,13 +262,15 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { if !webtransportMatcher.Matches(laddr) { return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) } - t.listenOnce.Do(func() { - t.certManager, t.listenOnceErr = newCertManager(t.clock) - }) - if t.listenOnceErr != nil { - return nil, t.listenOnceErr + if t.staticTLSConf == nil { + t.listenOnce.Do(func() { + t.certManager, t.listenOnceErr = newCertManager(t.clock) + }) + if t.listenOnceErr != nil { + return nil, t.listenOnceErr + } } - return newListener(laddr, t, t.noise, t.certManager, t.gater, t.rcmgr) + return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr) } func (t *transport) Protocols() []int { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 61cd245138..e83cac02e5 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -2,12 +2,19 @@ package libp2pwebtransport_test import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" "fmt" "io" + "math/big" "net" + "strings" "testing" "time" @@ -64,6 +71,19 @@ func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr { } } +// create a /certhash multiaddr component using the SHA256 of foobar +func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr { + t.Helper() + h := sha256.Sum256(b) + mh, err := multihash.Encode(h[:], multihash.SHA2_256) + require.NoError(t, err) + certStr, err := multibase.Encode(multibase.Base58BTC, mh) + require.NoError(t, err) + ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) + require.NoError(t, err) + return ha +} + func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) @@ -129,29 +149,11 @@ func TestHashVerification(t *testing.T) { require.NoError(t, err) defer tr2.(io.Closer).Close() - // create a hash component using the SHA256 of foobar - h := sha256.Sum256([]byte("foobar")) - mh, err := multihash.Encode(h[:], multihash.SHA2_256) - require.NoError(t, err) - certStr, err := multibase.Encode(multibase.Base58BTC, mh) - require.NoError(t, err) - foobarHash, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) - require.NoError(t, err) + foobarHash := getCerthashComponent(t, []byte("foobar")) t.Run("fails using only a wrong hash", func(t *testing.T) { // replace the certificate hash in the multiaddr with a fake hash - addr := ln.Multiaddr() - // strip off all certhash components - for { - a, comp := ma.SplitLast(addr) - if comp.Protocol().Code != ma.P_CERTHASH { - break - } - addr = a - } - - addr = addr.Encapsulate(foobarHash) - + addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash) _, err := tr2.Dial(context.Background(), addr, serverID) require.Error(t, err) require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found") @@ -424,3 +426,84 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { t.Fatal("timeout") } } + +func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { + t.Helper() + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1234), + Subject: pkix.Name{Organization: []string{"webtransport"}}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, + } + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv) + require.NoError(t, err) + cert, err := x509.ParseCertificate(caBytes) + require.NoError(t, err) + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: priv, + Leaf: cert, + }}, + } +} + +func TestStaticTLSConf(t *testing.T) { + tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour)) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf)) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash") + + t.Run("fails when the certificate is invalid", func(t *testing.T) { + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.Error(t, err) + if !strings.Contains(err.Error(), "certificate is not trusted") && + !strings.Contains(err.Error(), "certificate signed by unknown authority") { + t.Fatalf("expected a certificate error, got %+v", err) + } + }) + + t.Run("fails when dialing with a wrong certhash", func(t *testing.T) { + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo"))) + _, err = cl.Dial(context.Background(), addr, serverID) + require.Error(t, err) + require.Contains(t, err.Error(), "cert hash not found") + }) + + t.Run("accepts a valid TLS certificate", func(t *testing.T) { + _, key := newIdentity(t) + store := x509.NewCertPool() + store.AddCert(tlsConf.Certificates[0].Leaf) + tlsConf := &tls.Config{RootCAs: store} + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf)) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + }) +} From 60a40710abe6857ff277683dca486dbb473c50e2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 29 Aug 2022 17:44:52 +0300 Subject: [PATCH 41/44] chore: update CI to Go 1.18 / 1.19, update webtransport-go to v0.1.0 --- p2p/transport/webtransport/cert_manager.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index c46a60b666..6fc363281d 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -40,10 +40,10 @@ func newCertConfig(start, end time.Time) (*certConfig, error) { } // Certificate renewal logic: -// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another -// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew). -// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate. -// At the same time, we stop advertising the certhash of the first cert and generate the next cert. +// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another +// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew). +// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate. +// At the same time, we stop advertising the certhash of the first cert and generate the next cert. type certManager struct { clock clock.Clock ctx context.Context From 4ce4e4f05edd9ba28650f08d4b9ada7bff8f6362 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 Sep 2022 11:57:25 +0300 Subject: [PATCH 42/44] only use positive numbers for x509.Certificate serial numbers --- p2p/transport/webtransport/crypto.go | 7 +++++-- p2p/transport/webtransport/crypto_test.go | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go index 9bb7f7a330..0ea71323af 100644 --- a/p2p/transport/webtransport/crypto.go +++ b/p2p/transport/webtransport/crypto.go @@ -37,9 +37,12 @@ func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, e if _, err := rand.Read(b); err != nil { return nil, nil, err } - serial := binary.BigEndian.Uint64(b) + serial := int64(binary.BigEndian.Uint64(b)) + if serial < 0 { + serial = -serial + } certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(int64(serial)), + SerialNumber: big.NewInt(serial), Subject: pkix.Name{}, NotBefore: start, NotAfter: end, diff --git a/p2p/transport/webtransport/crypto_test.go b/p2p/transport/webtransport/crypto_test.go index c69c578b92..d6d106202a 100644 --- a/p2p/transport/webtransport/crypto_test.go +++ b/p2p/transport/webtransport/crypto_test.go @@ -31,8 +31,12 @@ func sha256Multihash(t *testing.T, b []byte) multihash.DecodedMultihash { func generateCertWithKey(t *testing.T, key crypto.PrivateKey, start, end time.Time) *x509.Certificate { t.Helper() + serial := int64(mrand.Uint64()) + if serial < 0 { + serial = -serial + } certTempl := &x509.Certificate{ - SerialNumber: big.NewInt(int64(mrand.Uint64())), + SerialNumber: big.NewInt(serial), Subject: pkix.Name{}, NotBefore: start, NotAfter: end, From 3521b4fae8c1e930bdf1b895a123a9eb7037f34f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 Sep 2022 12:07:15 +0300 Subject: [PATCH 43/44] chore: update go-multiaddr to v0.6.0 --- p2p/transport/webtransport/multiaddr_test.go | 12 ------------ p2p/transport/webtransport/transport_test.go | 5 ++++- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go index 7c95b00d01..3af1ce56e2 100644 --- a/p2p/transport/webtransport/multiaddr_test.go +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -73,16 +73,4 @@ func TestExtractCertHashes(t *testing.T) { require.Equal(t, h, string(ch[i].Digest)) } } - - // invalid cases - for _, tc := range [...]struct { - addr string - err string - }{ - {addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", fooHash[:len(fooHash)-1]), err: "failed to multihash-decode certificate hash"}, - } { - _, err := extractCertHashes(ma.StringCast(tc.addr)) - require.Error(t, err) - require.Contains(t, err.Error(), tc.err) - } } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index e83cac02e5..d436af4bd2 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -43,9 +43,12 @@ func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) { } func randomMultihash(t *testing.T) string { + t.Helper() b := make([]byte, 16) rand.Read(b) - s, err := multibase.Encode(multibase.Base32hex, b) + h, err := multihash.Encode(b, multihash.KECCAK_224) + require.NoError(t, err) + s, err := multibase.Encode(multibase.Base32hex, h) require.NoError(t, err) return s } From 97e739f0a872abf72f75c26549aa70870140ddd9 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 3 Sep 2022 13:52:50 +0300 Subject: [PATCH 44/44] update to the current master of go-libp2p (#23) --- p2p/transport/webtransport/conn.go | 4 +- p2p/transport/webtransport/listener.go | 75 +++++++++++++------ .../mock_connection_gater_test.go | 8 +- .../webtransport/noise_early_data.go | 34 +++++++++ p2p/transport/webtransport/stream.go | 2 +- p2p/transport/webtransport/transport.go | 48 ++++-------- p2p/transport/webtransport/transport_test.go | 10 +-- 7 files changed, 113 insertions(+), 68 deletions(-) create mode 100644 p2p/transport/webtransport/noise_early_data.go diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 409cea3383..f68595293f 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -3,8 +3,8 @@ package libp2pwebtransport import ( "context" - "github.com/libp2p/go-libp2p-core/network" - tpt "github.com/libp2p/go-libp2p-core/transport" + "github.com/libp2p/go-libp2p/core/network" + tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 872e65804b..4c87512d29 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -5,15 +5,16 @@ import ( "crypto/tls" "errors" "fmt" + pb "github.com/marten-seemann/go-libp2p-webtransport/pb" + "github.com/multiformats/go-multihash" "net" "net/http" "time" - "github.com/libp2p/go-libp2p-core/connmgr" - "github.com/libp2p/go-libp2p-core/network" - tpt "github.com/libp2p/go-libp2p-core/transport" - - noise "github.com/libp2p/go-libp2p-noise" + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/network" + tpt "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/security/noise" "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" @@ -27,10 +28,11 @@ const queueLen = 16 const handshakeTimeout = 10 * time.Second type listener struct { - transport tpt.Transport - noise *noise.Transport - certManager *certManager - staticTLSConf *tls.Config + transport tpt.Transport + noise *noise.Transport + certManager *certManager + tlsConf *tls.Config + isStaticTLSConf bool rcmgr network.ResourceManager gater connmgr.ConnectionGater @@ -67,23 +69,25 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans if err != nil { return nil, err } + isStaticTLSConf := tlsConf != nil if tlsConf == nil { tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return certManager.GetConfig(), nil }} } ln := &listener{ - transport: transport, - noise: noise, - certManager: certManager, - staticTLSConf: tlsConf, - rcmgr: rcmgr, - gater: gater, - queue: make(chan tpt.CapableConn, queueLen), - serverClosed: make(chan struct{}), - addr: udpConn.LocalAddr(), - multiaddr: localMultiaddr, - server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, + transport: transport, + noise: noise, + certManager: certManager, + tlsConf: tlsConf, + isStaticTLSConf: isStaticTLSConf, + rcmgr: rcmgr, + gater: gater, + queue: make(chan tpt.CapableConn, queueLen), + serverClosed: make(chan struct{}), + addr: udpConn.LocalAddr(), + multiaddr: localMultiaddr, + server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, } ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) mux := http.NewServeMux() @@ -184,7 +188,11 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (c if err != nil { return nil, err } - c, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") + n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData))) + if err != nil { + return nil, fmt.Errorf("failed to initialize Noise session: %w", err) + } + c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") if err != nil { return nil, err } @@ -195,6 +203,31 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (c }, nil } +func (l *listener) checkEarlyData(b []byte) error { + var msg pb.WebTransport + if err := msg.Unmarshal(b); err != nil { + fmt.Println(1) + return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) + } + + if l.isStaticTLSConf { + if len(msg.CertHashes) > 0 { + return errors.New("using static TLS config, didn't expect any certificate hashes") + } + return nil + } + + hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) + for _, h := range msg.CertHashes { + dh, err := multihash.Decode(h) + if err != nil { + return fmt.Errorf("failed to decode hash: %w", err) + } + hashes = append(hashes, *dh) + } + return l.certManager.Verify(hashes) +} + func (l *listener) Addr() net.Addr { return l.addr } diff --git a/p2p/transport/webtransport/mock_connection_gater_test.go b/p2p/transport/webtransport/mock_connection_gater_test.go index 071ed7494f..c6e7dbaad6 100644 --- a/p2p/transport/webtransport/mock_connection_gater_test.go +++ b/p2p/transport/webtransport/mock_connection_gater_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/libp2p/go-libp2p-core/connmgr (interfaces: ConnectionGater) +// Source: github.com/libp2p/go-libp2p/core/connmgr (interfaces: ConnectionGater) // Package libp2pwebtransport_test is a generated GoMock package. package libp2pwebtransport_test @@ -8,9 +8,9 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - control "github.com/libp2p/go-libp2p-core/control" - network "github.com/libp2p/go-libp2p-core/network" - peer "github.com/libp2p/go-libp2p-core/peer" + control "github.com/libp2p/go-libp2p/core/control" + network "github.com/libp2p/go-libp2p/core/network" + peer "github.com/libp2p/go-libp2p/core/peer" multiaddr "github.com/multiformats/go-multiaddr" ) diff --git a/p2p/transport/webtransport/noise_early_data.go b/p2p/transport/webtransport/noise_early_data.go new file mode 100644 index 0000000000..ec01c6d7a2 --- /dev/null +++ b/p2p/transport/webtransport/noise_early_data.go @@ -0,0 +1,34 @@ +package libp2pwebtransport + +import ( + "context" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/security/noise" + "net" +) + +type earlyDataHandler struct { + earlyData []byte + receive func([]byte) error +} + +var _ noise.EarlyDataHandler = &earlyDataHandler{} + +func newEarlyDataSender(earlyData []byte) noise.EarlyDataHandler { + return &earlyDataHandler{earlyData: earlyData} +} + +func newEarlyDataReceiver(receive func([]byte) error) noise.EarlyDataHandler { + return &earlyDataHandler{receive: receive} +} + +func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) []byte { + return e.earlyData +} + +func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, b []byte) error { + if e.receive == nil { + return nil + } + return e.receive(b) +} diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index 6aa58cb8d8..ff17b3083f 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -6,7 +6,7 @@ import ( "github.com/marten-seemann/webtransport-go" - "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p/core/network" ) const ( diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index e027e1fe8f..faa13db86c 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "fmt" "io" "sync" @@ -12,13 +11,12 @@ import ( pb "github.com/marten-seemann/go-libp2p-webtransport/pb" - "github.com/libp2p/go-libp2p-core/connmgr" - ic "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - tpt "github.com/libp2p/go-libp2p-core/transport" - - noise "github.com/libp2p/go-libp2p-noise" + "github.com/libp2p/go-libp2p/core/connmgr" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/security/noise" "github.com/benbjohnson/clock" logging "github.com/ipfs/go-log/v2" @@ -102,11 +100,11 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } } - noise, err := noise.New(key, noise.WithEarlyDataHandler(t.checkEarlyData)) + n, err := noise.New(key) if err != nil { return nil, err } - t.noise = noise + t.noise = n return t, nil } @@ -207,7 +205,11 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - c, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes) + n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes))) + if err != nil { + return nil, fmt.Errorf("failed to create Noise transport: %w", err) + } + c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p) if err != nil { return nil, err } @@ -217,30 +219,6 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p }, nil } -func (t *transport) checkEarlyData(b []byte) error { - var msg pb.WebTransport - if err := msg.Unmarshal(b); err != nil { - return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) - } - hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) - - if t.staticTLSConf != nil { - if len(hashes) > 0 { - return errors.New("using static TLS config, didn't expect any certificate hashes") - } - return nil - } - - for _, h := range msg.CertHashes { - dh, err := multihash.Decode(h) - if err != nil { - return fmt.Errorf("failed to decode hash: %w", err) - } - hashes = append(hashes, *dh) - } - return t.certManager.Verify(hashes) -} - func (t *transport) CanDial(addr ma.Multiaddr) bool { var numHashes int ma.ForEach(addr, func(c ma.Component) bool { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index d436af4bd2..02f065bc8c 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -20,12 +20,12 @@ import ( libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" - ic "github.com/libp2p/go-libp2p-core/crypto" - "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" + "github.com/libp2p/go-libp2p/core/peer" "github.com/golang/mock/gomock" - mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" @@ -338,7 +338,7 @@ func TestResourceManagerListening(t *testing.T) { } // TODO: unify somehow. We do the same in libp2pquic. -//go:generate sh -c "mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" +//go:generate sh -c "mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" func TestConnectionGaterDialing(t *testing.T) { ctrl := gomock.NewController(t)