diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go new file mode 100644 index 0000000000..6fc363281d --- /dev/null +++ b/p2p/transport/webtransport/cert_manager.go @@ -0,0 +1,181 @@ +package libp2pwebtransport + +import ( + "bytes" + "context" + "crypto/sha256" + "crypto/tls" + "fmt" + "sync" + "time" + + "github.com/benbjohnson/clock" + ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" +) + +// Allow for a bit of clock skew. +// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time. +// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time. +const clockSkewAllowance = time.Hour + +type certConfig struct { + tlsConf *tls.Config + sha256 [32]byte // cached from the tlsConf +} + +func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore } +func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter } + +func newCertConfig(start, end time.Time) (*certConfig, error) { + conf, err := getTLSConf(start, end) + if err != nil { + return nil, err + } + return &certConfig{ + tlsConf: conf, + sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw), + }, nil +} + +// Certificate renewal logic: +// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another +// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew). +// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate. +// At the same time, we stop advertising the certhash of the first cert and generate the next cert. +type certManager struct { + clock clock.Clock + ctx context.Context + ctxCancel context.CancelFunc + refCount sync.WaitGroup + + mx sync.RWMutex + lastConfig *certConfig // initially nil + currentConfig *certConfig + nextConfig *certConfig // nil until we have passed half the certValidity of the current config + addrComp ma.Multiaddr +} + +func newCertManager(clock clock.Clock) (*certManager, error) { + m := &certManager{clock: clock} + m.ctx, m.ctxCancel = context.WithCancel(context.Background()) + if err := m.init(); err != nil { + return nil, err + } + + m.background() + return m, nil +} + +func (m *certManager) init() error { + start := m.clock.Now().Add(-clockSkewAllowance) + var err error + m.nextConfig, err = newCertConfig(start, start.Add(certValidity)) + if err != nil { + return err + } + return m.rollConfig() +} + +func (m *certManager) rollConfig() error { + // We stop using the current certificate clockSkewAllowance before its expiry time. + // At this point, the next certificate needs to be valid for one clockSkewAllowance. + nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance) + c, err := newCertConfig(nextStart, nextStart.Add(certValidity)) + if err != nil { + return err + } + m.lastConfig = m.currentConfig + m.currentConfig = m.nextConfig + m.nextConfig = c + return m.cacheAddrComponent() +} + +func (m *certManager) background() { + d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now()) + log.Debugw("setting timer", "duration", d.String()) + t := m.clock.Timer(d) + m.refCount.Add(1) + + go func() { + defer m.refCount.Done() + defer t.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case now := <-t.C: + m.mx.Lock() + if err := m.rollConfig(); err != nil { + log.Errorw("rolling config failed", "error", err) + } + d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now) + log.Debugw("rolling certificates", "next", d.String()) + t.Reset(d) + m.mx.Unlock() + } + } + }() +} + +func (m *certManager) GetConfig() *tls.Config { + m.mx.RLock() + defer m.mx.RUnlock() + return m.currentConfig.tlsConf +} + +func (m *certManager) AddrComponent() ma.Multiaddr { + m.mx.RLock() + defer m.mx.RUnlock() + return m.addrComp +} + +func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error { + for _, h := range hashes { + if h.Code != multihash.SHA2_256 { + return fmt.Errorf("expected SHA256 hash, got %d", h.Code) + } + if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) && + (m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) && + (m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) { + return fmt.Errorf("found unexpected hash: %+x", h.Digest) + } + } + return nil +} + +func (m *certManager) cacheAddrComponent() error { + addr, err := m.addrComponentForCert(m.currentConfig.sha256[:]) + if err != nil { + return err + } + if m.nextConfig != nil { + comp, err := m.addrComponentForCert(m.nextConfig.sha256[:]) + if err != nil { + return err + } + addr = addr.Encapsulate(comp) + } + m.addrComp = addr + return nil +} + +func (m *certManager) addrComponentForCert(hash []byte) (ma.Multiaddr, error) { + mh, err := multihash.Encode(hash, multihash.SHA2_256) + if err != nil { + return nil, err + } + certStr, err := multibase.Encode(multibase.Base58BTC, mh) + if err != nil { + return nil, err + } + return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) +} + +func (m *certManager) Close() error { + m.ctxCancel() + m.refCount.Wait() + return nil +} diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go new file mode 100644 index 0000000000..3f2328fbb7 --- /dev/null +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -0,0 +1,102 @@ +package libp2pwebtransport + +import ( + "crypto/sha256" + "crypto/tls" + "testing" + "time" + + "github.com/benbjohnson/clock" + ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" +) + +func certificateHashFromTLSConfig(c *tls.Config) [32]byte { + return sha256.Sum256(c.Certificates[0].Certificate[0]) +} + +func splitMultiaddr(addr ma.Multiaddr) []ma.Component { + var components []ma.Component + ma.ForEach(addr, func(c ma.Component) bool { + components = append(components, c) + return true + }) + return components +} + +func certHashFromComponent(t *testing.T, comp ma.Component) []byte { + t.Helper() + _, data, err := multibase.Decode(comp.Value()) + require.NoError(t, err) + mh, err := multihash.Decode(data) + require.NoError(t, err) + require.Equal(t, uint64(multihash.SHA2_256), mh.Code) + return mh.Digest +} + +func TestInitialCert(t *testing.T) { + cl := clock.NewMock() + cl.Add(1234567 * time.Hour) + m, err := newCertManager(cl) + require.NoError(t, err) + defer m.Close() + + conf := m.GetConfig() + require.Len(t, conf.Certificates, 1) + cert := conf.Certificates[0] + require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore) + require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter) + addr := m.AddrComponent() + components := splitMultiaddr(addr) + require.Len(t, components, 2) + require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) + hash := certificateHashFromTLSConfig(conf) + require.Equal(t, hash[:], certHashFromComponent(t, components[0])) + require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code) +} + +func TestCertRenewal(t *testing.T) { + cl := clock.NewMock() + m, err := newCertManager(cl) + require.NoError(t, err) + defer m.Close() + + firstConf := m.GetConfig() + first := splitMultiaddr(m.AddrComponent()) + require.Len(t, first, 2) + require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ") + // wait for a new certificate to be generated + cl.Add(certValidity - 2*clockSkewAllowance - time.Second) + require.Never(t, func() bool { + for i, c := range splitMultiaddr(m.AddrComponent()) { + if c.Value() != first[i].Value() { + return true + } + } + return false + }, 100*time.Millisecond, 10*time.Millisecond) + cl.Add(2 * time.Second) + require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) + secondConf := m.GetConfig() + + second := splitMultiaddr(m.AddrComponent()) + require.Len(t, second, 2) + for _, c := range second { + require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) + } + // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate + require.Equal(t, first[1].Value(), second[0].Value()) + require.NotEqual(t, first[0].Value(), second[1].Value()) + + cl.Add(certValidity - 2*clockSkewAllowance + time.Second) + require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond) + third := splitMultiaddr(m.AddrComponent()) + require.Len(t, third, 2) + for _, c := range third { + require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) + } + // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate + require.Equal(t, second[1].Value(), third[0].Value()) +} diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go new file mode 100644 index 0000000000..f68595293f --- /dev/null +++ b/p2p/transport/webtransport/conn.go @@ -0,0 +1,65 @@ +package libp2pwebtransport + +import ( + "context" + + "github.com/libp2p/go-libp2p/core/network" + tpt "github.com/libp2p/go-libp2p/core/transport" + + "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" +) + +type connSecurityMultiaddrs interface { + network.ConnMultiaddrs + network.ConnSecurity +} + +type connSecurityMultiaddrsImpl struct { + network.ConnSecurity + network.ConnMultiaddrs +} + +type connMultiaddrs struct { + local, remote ma.Multiaddr +} + +var _ network.ConnMultiaddrs = &connMultiaddrs{} + +func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } +func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } + +type conn struct { + connSecurityMultiaddrs + + transport tpt.Transport + session *webtransport.Session + + scope network.ConnScope +} + +var _ tpt.CapableConn = &conn{} + +func newConn(tr tpt.Transport, sess *webtransport.Session, sconn connSecurityMultiaddrs, scope network.ConnScope) *conn { + return &conn{ + connSecurityMultiaddrs: sconn, + transport: tr, + session: sess, + scope: scope, + } +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + str, err := c.session.OpenStreamSync(ctx) + return &stream{str}, err +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + str, err := c.session.AcceptStream(context.Background()) + return &stream{str}, err +} + +func (c *conn) Close() error { return c.session.Close() } +func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } +func (c *conn) Scope() network.ConnScope { return c.scope } +func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go new file mode 100644 index 0000000000..0ea71323af --- /dev/null +++ b/p2p/transport/webtransport/crypto.go @@ -0,0 +1,108 @@ +package libp2pwebtransport + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/binary" + "errors" + "fmt" + "math/big" + "time" + + "github.com/multiformats/go-multihash" +) + +func getTLSConf(start, end time.Time) (*tls.Config, error) { + cert, priv, err := generateCert(start, end) + if err != nil { + return nil, err + } + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: priv, + Leaf: cert, + }}, + }, nil +} + +func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return nil, nil, err + } + serial := int64(binary.BigEndian.Uint64(b)) + if serial < 0 { + serial = -serial + } + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(serial), + Subject: pkix.Name{}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + return ca, caPrivateKey, nil +} + +func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash) error { + if len(rawCerts) < 1 { + return errors.New("no cert") + } + leaf := rawCerts[len(rawCerts)-1] + // The W3C WebTransport specification currently only allows SHA-256 certificates for serverCertificateHashes. + hash := sha256.Sum256(leaf) + var verified bool + for _, h := range certHashes { + if h.Code == multihash.SHA2_256 && bytes.Equal(h.Digest, hash[:]) { + verified = true + break + } + } + if !verified { + digests := make([][]byte, 0, len(certHashes)) + for _, h := range certHashes { + digests = append(digests, h.Digest) + } + return fmt.Errorf("cert hash not found: %#x (expected: %#x)", hash, digests) + } + + cert, err := x509.ParseCertificate(leaf) + if err != nil { + return err + } + // TODO: is this the best (and complete?) way to identify RSA certificates? + switch cert.SignatureAlgorithm { + case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA, x509.MD2WithRSA, x509.MD5WithRSA: + return errors.New("cert uses RSA") + } + if l := cert.NotAfter.Sub(cert.NotBefore); l > 14*24*time.Hour { + return fmt.Errorf("cert must not be valid for longer than 14 days (NotBefore: %s, NotAfter: %s, Length: %s)", cert.NotBefore, cert.NotAfter, l) + } + now := time.Now() + if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { + return fmt.Errorf("cert not valid (NotBefore: %s, NotAfter: %s)", cert.NotBefore, cert.NotAfter) + } + return nil +} diff --git a/p2p/transport/webtransport/crypto_test.go b/p2p/transport/webtransport/crypto_test.go new file mode 100644 index 0000000000..d6d106202a --- /dev/null +++ b/p2p/transport/webtransport/crypto_test.go @@ -0,0 +1,132 @@ +package libp2pwebtransport + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + mrand "math/rand" + "testing" + "time" + + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" +) + +func sha256Multihash(t *testing.T, b []byte) multihash.DecodedMultihash { + t.Helper() + hash := sha256.Sum256(b) + h, err := multihash.Encode(hash[:], multihash.SHA2_256) + require.NoError(t, err) + dh, err := multihash.Decode(h) + require.NoError(t, err) + return *dh +} + +func generateCertWithKey(t *testing.T, key crypto.PrivateKey, start, end time.Time) *x509.Certificate { + t.Helper() + serial := int64(mrand.Uint64()) + if serial < 0 { + serial = -serial + } + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(serial), + Subject: pkix.Name{}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, key.(interface{ Public() crypto.PublicKey }).Public(), key) + require.NoError(t, err) + ca, err := x509.ParseCertificate(caBytes) + require.NoError(t, err) + return ca +} + +func TestCertificateVerification(t *testing.T) { + now := time.Now() + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + t.Run("accepting a valid cert", func(t *testing.T) { + validCert := generateCertWithKey(t, ecdsaKey, now, now.Add(14*24*time.Hour)) + require.NoError(t, verifyRawCerts([][]byte{validCert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, validCert.Raw)})) + }) + + for _, tc := range [...]struct { + name string + cert *x509.Certificate + errStr string + }{ + { + name: "validitity period too long", + cert: generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)), + errStr: "cert must not be valid for longer than 14 days", + }, + { + name: "uses RSA key", + cert: generateCertWithKey(t, rsaKey, now, now.Add(14*24*time.Hour)), + errStr: "RSA", + }, + { + name: "expired certificate", + cert: generateCertWithKey(t, ecdsaKey, now.Add(-14*24*time.Hour), now), + errStr: "cert not valid", + }, + { + name: "not yet valid", + cert: generateCertWithKey(t, ecdsaKey, now.Add(time.Hour), now.Add(time.Hour+14*24*time.Hour)), + errStr: "cert not valid", + }, + } { + tc := tc + t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) { + err := verifyRawCerts([][]byte{tc.cert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, tc.cert.Raw)}) + require.Error(t, err) + require.Contains(t, err.Error(), tc.errStr) + }) + } + + for _, tc := range [...]struct { + name string + certs [][]byte + hashes []multihash.DecodedMultihash + errStr string + }{ + { + name: "no certificates", + hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, + errStr: "no cert", + }, + { + name: "certificate not parseable", + certs: [][]byte{[]byte("foobar")}, + hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, + errStr: "x509: malformed certificate", + }, + { + name: "hash mismatch", + certs: [][]byte{generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)).Raw}, + hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, + errStr: "cert hash not found", + }, + } { + tc := tc + t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) { + err := verifyRawCerts(tc.certs, tc.hashes) + require.Error(t, err) + require.Contains(t, err.Error(), tc.errStr) + }) + } +} diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go new file mode 100644 index 0000000000..4c87512d29 --- /dev/null +++ b/p2p/transport/webtransport/listener.go @@ -0,0 +1,247 @@ +package libp2pwebtransport + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + pb "github.com/marten-seemann/go-libp2p-webtransport/pb" + "github.com/multiformats/go-multihash" + "net" + "net/http" + "time" + + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/network" + tpt "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/security/noise" + + "github.com/lucas-clemente/quic-go/http3" + "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +var errClosed = errors.New("closed") + +const queueLen = 16 +const handshakeTimeout = 10 * time.Second + +type listener struct { + transport tpt.Transport + noise *noise.Transport + certManager *certManager + tlsConf *tls.Config + isStaticTLSConf bool + + rcmgr network.ResourceManager + gater connmgr.ConnectionGater + + server webtransport.Server + + ctx context.Context + ctxCancel context.CancelFunc + + serverClosed chan struct{} // is closed when server.Serve returns + + addr net.Addr + multiaddr ma.Multiaddr + + queue chan tpt.CapableConn +} + +var _ tpt.Listener = &listener{} + +func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) { + network, addr, err := manet.DialArgs(laddr) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP(network, udpAddr) + if err != nil { + return nil, err + } + localMultiaddr, err := toWebtransportMultiaddr(udpConn.LocalAddr()) + if err != nil { + return nil, err + } + isStaticTLSConf := tlsConf != nil + if tlsConf == nil { + tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return certManager.GetConfig(), nil + }} + } + ln := &listener{ + transport: transport, + noise: noise, + certManager: certManager, + tlsConf: tlsConf, + isStaticTLSConf: isStaticTLSConf, + rcmgr: rcmgr, + gater: gater, + queue: make(chan tpt.CapableConn, queueLen), + serverClosed: make(chan struct{}), + addr: udpConn.LocalAddr(), + multiaddr: localMultiaddr, + server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, + } + ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, world!")) + }) + mux.HandleFunc(webtransportHTTPEndpoint, ln.httpHandler) + ln.server.H3.Handler = mux + go func() { + defer close(ln.serverClosed) + defer func() { udpConn.Close() }() + if err := ln.server.Serve(udpConn); err != nil { + // TODO: only output if the server hasn't been closed + log.Debugw("serving failed", "addr", udpConn.LocalAddr(), "error", err) + } + }() + return ln, nil +} + +func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { + remoteMultiaddr, err := stringToWebtransportMultiaddr(r.RemoteAddr) + if err != nil { + // This should never happen. + log.Errorw("converting remote address failed", "remote", r.RemoteAddr, "error", err) + w.WriteHeader(http.StatusBadRequest) + return + } + if l.gater != nil && !l.gater.InterceptAccept(&connMultiaddrs{local: l.multiaddr, remote: remoteMultiaddr}) { + w.WriteHeader(http.StatusForbidden) + return + } + + connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr) + if err != nil { + log.Debugw("resource manager blocked incoming connection", "addr", r.RemoteAddr, "error", err) + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + // TODO: check ?type=multistream URL param + sess, err := l.server.Upgrade(w, r) + if err != nil { + log.Debugw("upgrade failed", "error", err) + // TODO: think about the status code to use here + w.WriteHeader(500) + connScope.Done() + return + } + ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) + sconn, err := l.handshake(ctx, sess) + if err != nil { + cancel() + log.Debugw("handshake failed", "error", err) + sess.Close() + connScope.Done() + return + } + cancel() + + if l.gater != nil && !l.gater.InterceptSecured(network.DirInbound, sconn.RemotePeer(), sconn) { + // TODO: can we close with a specific error here? + sess.Close() + connScope.Done() + return + } + + if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { + log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) + sess.Close() + connScope.Done() + return + } + + // TODO: think about what happens when this channel fills up + l.queue <- newConn(l.transport, sess, sconn, connScope) +} + +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case <-l.ctx.Done(): + return nil, errClosed + case c := <-l.queue: + return c, nil + } +} + +func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (connSecurityMultiaddrs, error) { + local, err := toWebtransportMultiaddr(sess.LocalAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting local addr: %w", err) + } + remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting remote addr: %w", err) + } + + str, err := sess.AcceptStream(ctx) + if err != nil { + return nil, err + } + n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData))) + if err != nil { + return nil, fmt.Errorf("failed to initialize Noise session: %w", err) + } + c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") + if err != nil { + return nil, err + } + + return &connSecurityMultiaddrsImpl{ + ConnSecurity: c, + ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, + }, nil +} + +func (l *listener) checkEarlyData(b []byte) error { + var msg pb.WebTransport + if err := msg.Unmarshal(b); err != nil { + fmt.Println(1) + return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) + } + + if l.isStaticTLSConf { + if len(msg.CertHashes) > 0 { + return errors.New("using static TLS config, didn't expect any certificate hashes") + } + return nil + } + + hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) + for _, h := range msg.CertHashes { + dh, err := multihash.Decode(h) + if err != nil { + return fmt.Errorf("failed to decode hash: %w", err) + } + hashes = append(hashes, *dh) + } + return l.certManager.Verify(hashes) +} + +func (l *listener) Addr() net.Addr { + return l.addr +} + +func (l *listener) Multiaddr() ma.Multiaddr { + if l.certManager == nil { + return l.multiaddr + } + return l.multiaddr.Encapsulate(l.certManager.AddrComponent()) +} + +func (l *listener) Close() error { + l.ctxCancel() + err := l.server.Close() + <-l.serverClosed + return err +} diff --git a/p2p/transport/webtransport/mock_connection_gater_test.go b/p2p/transport/webtransport/mock_connection_gater_test.go new file mode 100644 index 0000000000..c6e7dbaad6 --- /dev/null +++ b/p2p/transport/webtransport/mock_connection_gater_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/libp2p/go-libp2p/core/connmgr (interfaces: ConnectionGater) + +// Package libp2pwebtransport_test is a generated GoMock package. +package libp2pwebtransport_test + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + control "github.com/libp2p/go-libp2p/core/control" + network "github.com/libp2p/go-libp2p/core/network" + peer "github.com/libp2p/go-libp2p/core/peer" + multiaddr "github.com/multiformats/go-multiaddr" +) + +// MockConnectionGater is a mock of ConnectionGater interface. +type MockConnectionGater struct { + ctrl *gomock.Controller + recorder *MockConnectionGaterMockRecorder +} + +// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater. +type MockConnectionGaterMockRecorder struct { + mock *MockConnectionGater +} + +// NewMockConnectionGater creates a new mock instance. +func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater { + mock := &MockConnectionGater{ctrl: ctrl} + mock.recorder = &MockConnectionGaterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder { + return m.recorder +} + +// InterceptAccept mocks base method. +func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAccept", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAccept indicates an expected call of InterceptAccept. +func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0) +} + +// InterceptAddrDial mocks base method. +func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptAddrDial indicates an expected call of InterceptAddrDial. +func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1) +} + +// InterceptPeerDial mocks base method. +func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptPeerDial", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptPeerDial indicates an expected call of InterceptPeerDial. +func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0) +} + +// InterceptSecured mocks base method. +func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// InterceptSecured indicates an expected call of InterceptSecured. +func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2) +} + +// InterceptUpgraded mocks base method. +func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InterceptUpgraded", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(control.DisconnectReason) + return ret0, ret1 +} + +// InterceptUpgraded indicates an expected call of InterceptUpgraded. +func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0) +} diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go new file mode 100644 index 0000000000..7b6df2ba25 --- /dev/null +++ b/p2p/transport/webtransport/multiaddr.go @@ -0,0 +1,68 @@ +package libp2pwebtransport + +import ( + "errors" + "fmt" + "net" + "strconv" + + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" +) + +var webtransportMA = ma.StringCast("/quic/webtransport") + +var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) + +func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { + addr, err := manet.FromNetAddr(na) + if err != nil { + return nil, err + } + if _, err := addr.ValueForProtocol(ma.P_UDP); err != nil { + return nil, errors.New("not a UDP address") + } + return addr.Encapsulate(webtransportMA), nil +} + +func stringToWebtransportMultiaddr(str string) (ma.Multiaddr, error) { + host, portStr, err := net.SplitHostPort(str) + if err != nil { + return nil, err + } + port, err := strconv.ParseInt(portStr, 10, 32) + if err != nil { + return nil, err + } + ip := net.ParseIP(host) + if ip == nil { + return nil, errors.New("failed to parse IP") + } + return toWebtransportMultiaddr(&net.UDPAddr{IP: ip, Port: int(port)}) +} + +func extractCertHashes(addr ma.Multiaddr) ([]multihash.DecodedMultihash, error) { + certHashesStr := make([]string, 0, 2) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + certHashesStr = append(certHashesStr, c.Value()) + } + return true + }) + certHashes := make([]multihash.DecodedMultihash, 0, len(certHashesStr)) + for _, s := range certHashesStr { + _, ch, err := multibase.Decode(s) + if err != nil { + return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err) + } + dh, err := multihash.Decode(ch) + if err != nil { + return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err) + } + certHashes = append(certHashes, *dh) + } + return certHashes, nil +} diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go new file mode 100644 index 0000000000..3af1ce56e2 --- /dev/null +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -0,0 +1,76 @@ +package libp2pwebtransport + +import ( + "fmt" + "net" + "testing" + + ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + "github.com/stretchr/testify/require" +) + +func TestWebtransportMultiaddr(t *testing.T) { + t.Run("valid", func(t *testing.T) { + addr, err := toWebtransportMultiaddr(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) + require.NoError(t, err) + require.Equal(t, "/ip4/127.0.0.1/udp/1337/quic/webtransport", addr.String()) + }) + + t.Run("invalid", func(t *testing.T) { + _, err := toWebtransportMultiaddr(&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) + require.EqualError(t, err, "not a UDP address") + }) +} + +func TestWebtransportMultiaddrFromString(t *testing.T) { + t.Run("valid", func(t *testing.T) { + addr, err := stringToWebtransportMultiaddr("1.2.3.4:60042") + require.NoError(t, err) + require.Equal(t, "/ip4/1.2.3.4/udp/60042/quic/webtransport", addr.String()) + }) + + t.Run("invalid", func(t *testing.T) { + for _, addr := range [...]string{ + "1.2.3.4", // missing port + "1.2.3.4:123456", // invalid port + ":1234", // missing IP + "foobar", + } { + _, err := stringToWebtransportMultiaddr(addr) + require.Error(t, err) + } + }) +} + +func encodeCertHash(t *testing.T, b []byte, mh uint64, mb multibase.Encoding) string { + t.Helper() + h, err := multihash.Encode(b, mh) + require.NoError(t, err) + str, err := multibase.Encode(mb, h) + require.NoError(t, err) + return str +} + +func TestExtractCertHashes(t *testing.T) { + fooHash := encodeCertHash(t, []byte("foo"), multihash.SHA2_256, multibase.Base58BTC) + barHash := encodeCertHash(t, []byte("bar"), multihash.BLAKE2B_MAX, multibase.Base32) + + // valid cases + for _, tc := range [...]struct { + addr string + hashes []string + }{ + {addr: "/ip4/127.0.0.1/udp/1234/quic/webtransport"}, + {addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", fooHash), hashes: []string{"foo"}}, + {addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s", fooHash, barHash), hashes: []string{"foo", "bar"}}, + } { + ch, err := extractCertHashes(ma.StringCast(tc.addr)) + require.NoError(t, err) + require.Len(t, ch, len(tc.hashes)) + for i, h := range tc.hashes { + require.Equal(t, h, string(ch[i].Digest)) + } + } +} diff --git a/p2p/transport/webtransport/noise_early_data.go b/p2p/transport/webtransport/noise_early_data.go new file mode 100644 index 0000000000..ec01c6d7a2 --- /dev/null +++ b/p2p/transport/webtransport/noise_early_data.go @@ -0,0 +1,34 @@ +package libp2pwebtransport + +import ( + "context" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/security/noise" + "net" +) + +type earlyDataHandler struct { + earlyData []byte + receive func([]byte) error +} + +var _ noise.EarlyDataHandler = &earlyDataHandler{} + +func newEarlyDataSender(earlyData []byte) noise.EarlyDataHandler { + return &earlyDataHandler{earlyData: earlyData} +} + +func newEarlyDataReceiver(receive func([]byte) error) noise.EarlyDataHandler { + return &earlyDataHandler{receive: receive} +} + +func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) []byte { + return e.earlyData +} + +func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, b []byte) error { + if e.receive == nil { + return nil + } + return e.receive(b) +} diff --git a/p2p/transport/webtransport/pb/Makefile b/p2p/transport/webtransport/pb/Makefile new file mode 100644 index 0000000000..8af2dd8177 --- /dev/null +++ b/p2p/transport/webtransport/pb/Makefile @@ -0,0 +1,11 @@ +PB = $(wildcard *.proto) +GO = $(PB:.proto=.pb.go) + +all: $(GO) + +%.pb.go: %.proto + protoc --proto_path=$(PWD)/../..:. --gogofaster_out=. $< + +clean: + rm -f *.pb.go + rm -f *.go diff --git a/p2p/transport/webtransport/pb/webtransport.pb.go b/p2p/transport/webtransport/pb/webtransport.pb.go new file mode 100644 index 0000000000..810d0667c9 --- /dev/null +++ b/p2p/transport/webtransport/pb/webtransport.pb.go @@ -0,0 +1,315 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: webtransport.proto + +package webtransport + +import ( + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type WebTransport struct { + CertHashes [][]byte `protobuf:"bytes,1,rep,name=cert_hashes,json=certHashes" json:"cert_hashes,omitempty"` +} + +func (m *WebTransport) Reset() { *m = WebTransport{} } +func (m *WebTransport) String() string { return proto.CompactTextString(m) } +func (*WebTransport) ProtoMessage() {} +func (*WebTransport) Descriptor() ([]byte, []int) { + return fileDescriptor_db878920ab41a4f3, []int{0} +} +func (m *WebTransport) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *WebTransport) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_WebTransport.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *WebTransport) XXX_Merge(src proto.Message) { + xxx_messageInfo_WebTransport.Merge(m, src) +} +func (m *WebTransport) XXX_Size() int { + return m.Size() +} +func (m *WebTransport) XXX_DiscardUnknown() { + xxx_messageInfo_WebTransport.DiscardUnknown(m) +} + +var xxx_messageInfo_WebTransport proto.InternalMessageInfo + +func (m *WebTransport) GetCertHashes() [][]byte { + if m != nil { + return m.CertHashes + } + return nil +} + +func init() { + proto.RegisterType((*WebTransport)(nil), "WebTransport") +} + +func init() { proto.RegisterFile("webtransport.proto", fileDescriptor_db878920ab41a4f3) } + +var fileDescriptor_db878920ab41a4f3 = []byte{ + // 109 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x4f, 0x4d, 0x2a, + 0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x57, 0xd2, + 0xe7, 0xe2, 0x09, 0x4f, 0x4d, 0x0a, 0x81, 0x89, 0x0a, 0xc9, 0x73, 0x71, 0x27, 0xa7, 0x16, 0x95, + 0xc4, 0x67, 0x24, 0x16, 0x67, 0xa4, 0x16, 0x4b, 0x30, 0x2a, 0x30, 0x6b, 0xf0, 0x04, 0x71, 0x81, + 0x84, 0x3c, 0xc0, 0x22, 0x4e, 0x12, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, + 0x91, 0x1c, 0xe3, 0x84, 0xc7, 0x72, 0x0c, 0x17, 0x1e, 0xcb, 0x31, 0xdc, 0x78, 0x2c, 0xc7, 0x00, + 0x08, 0x00, 0x00, 0xff, 0xff, 0x50, 0x77, 0xe5, 0x52, 0x5f, 0x00, 0x00, 0x00, +} + +func (m *WebTransport) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *WebTransport) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *WebTransport) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.CertHashes) > 0 { + for iNdEx := len(m.CertHashes) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.CertHashes[iNdEx]) + copy(dAtA[i:], m.CertHashes[iNdEx]) + i = encodeVarintWebtransport(dAtA, i, uint64(len(m.CertHashes[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func encodeVarintWebtransport(dAtA []byte, offset int, v uint64) int { + offset -= sovWebtransport(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *WebTransport) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.CertHashes) > 0 { + for _, b := range m.CertHashes { + l = len(b) + n += 1 + l + sovWebtransport(uint64(l)) + } + } + return n +} + +func sovWebtransport(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozWebtransport(x uint64) (n int) { + return sovWebtransport(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *WebTransport) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowWebtransport + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: WebTransport: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: WebTransport: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CertHashes", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowWebtransport + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthWebtransport + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthWebtransport + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.CertHashes = append(m.CertHashes, make([]byte, postIndex-iNdEx)) + copy(m.CertHashes[len(m.CertHashes)-1], dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipWebtransport(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthWebtransport + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipWebtransport(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowWebtransport + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowWebtransport + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowWebtransport + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthWebtransport + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupWebtransport + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthWebtransport + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthWebtransport = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowWebtransport = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupWebtransport = fmt.Errorf("proto: unexpected end of group") +) diff --git a/p2p/transport/webtransport/pb/webtransport.proto b/p2p/transport/webtransport/pb/webtransport.proto new file mode 100644 index 0000000000..a9129ce6c6 --- /dev/null +++ b/p2p/transport/webtransport/pb/webtransport.proto @@ -0,0 +1,5 @@ +syntax = "proto2"; + +message WebTransport { + repeated bytes cert_hashes = 1; +} diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go new file mode 100644 index 0000000000..ff17b3083f --- /dev/null +++ b/p2p/transport/webtransport/stream.go @@ -0,0 +1,71 @@ +package libp2pwebtransport + +import ( + "errors" + "net" + + "github.com/marten-seemann/webtransport-go" + + "github.com/libp2p/go-libp2p/core/network" +) + +const ( + reset webtransport.ErrorCode = 0 +) + +type webtransportStream struct { + webtransport.Stream + wsess *webtransport.Session +} + +var _ net.Conn = &webtransportStream{} + +func (s *webtransportStream) LocalAddr() net.Addr { + return s.wsess.LocalAddr() +} + +func (s *webtransportStream) RemoteAddr() net.Addr { + return s.wsess.RemoteAddr() +} + +type stream struct { + webtransport.Stream +} + +var _ network.MuxedStream = &stream{} + +func (s *stream) Read(b []byte) (n int, err error) { + n, err = s.Stream.Read(b) + if err != nil && errors.Is(err, &webtransport.StreamError{}) { + err = network.ErrReset + } + return n, err +} + +func (s *stream) Write(b []byte) (n int, err error) { + n, err = s.Stream.Write(b) + if err != nil && errors.Is(err, &webtransport.StreamError{}) { + err = network.ErrReset + } + return n, err +} + +func (s *stream) Reset() error { + s.Stream.CancelRead(reset) + s.Stream.CancelWrite(reset) + return nil +} + +func (s *stream) Close() error { + s.Stream.CancelRead(reset) + return s.Stream.Close() +} + +func (s *stream) CloseRead() error { + s.Stream.CancelRead(reset) + return nil +} + +func (s *stream) CloseWrite() error { + return s.Stream.Close() +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go new file mode 100644 index 0000000000..faa13db86c --- /dev/null +++ b/p2p/transport/webtransport/transport.go @@ -0,0 +1,268 @@ +package libp2pwebtransport + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "sync" + "time" + + pb "github.com/marten-seemann/go-libp2p-webtransport/pb" + + "github.com/libp2p/go-libp2p/core/connmgr" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/security/noise" + + "github.com/benbjohnson/clock" + logging "github.com/ipfs/go-log/v2" + "github.com/lucas-clemente/quic-go/http3" + "github.com/marten-seemann/webtransport-go" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multihash" +) + +var log = logging.Logger("webtransport") + +const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" + +const certValidity = 14 * 24 * time.Hour + +type Option func(*transport) error + +func WithClock(cl clock.Clock) Option { + return func(t *transport) error { + t.clock = cl + return nil + } +} + +// WithTLSConfig sets a tls.Config used for listening. +// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr. +// This is most useful when running a listener that has a valid (CA-signed) certificate. +func WithTLSConfig(c *tls.Config) Option { + return func(t *transport) error { + t.staticTLSConf = c + return nil + } +} + +// WithTLSClientConfig sets a custom tls.Config used for dialing. +// This option is most useful for setting a custom tls.Config.RootCAs certificate pool. +// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and +// overwrite the VerifyPeerCertificate callback. +func WithTLSClientConfig(c *tls.Config) Option { + return func(t *transport) error { + t.tlsClientConf = c + return nil + } +} + +type transport struct { + privKey ic.PrivKey + pid peer.ID + clock clock.Clock + + rcmgr network.ResourceManager + gater connmgr.ConnectionGater + + listenOnce sync.Once + listenOnceErr error + certManager *certManager + staticTLSConf *tls.Config + tlsClientConf *tls.Config + + noise *noise.Transport +} + +var _ tpt.Transport = &transport{} +var _ io.Closer = &transport{} + +func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { + id, err := peer.IDFromPrivateKey(key) + if err != nil { + return nil, err + } + t := &transport{ + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + } + for _, opt := range opts { + if err := opt(t); err != nil { + return nil, err + } + } + n, err := noise.New(key) + if err != nil { + return nil, err + } + t.noise = n + return t, nil +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + _, addr, err := manet.DialArgs(raddr) + if err != nil { + return nil, err + } + certHashes, err := extractCertHashes(raddr) + if err != nil { + return nil, err + } + + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) + return nil, err + } + if err := scope.SetPeer(p); err != nil { + log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) + scope.Done() + return nil, err + } + + sess, err := t.dial(ctx, addr, certHashes) + if err != nil { + scope.Done() + return nil, err + } + sconn, err := t.upgrade(ctx, sess, p, certHashes) + if err != nil { + sess.Close() + scope.Done() + return nil, err + } + if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, sconn) { + // TODO: can we close with a specific error here? + sess.Close() + return nil, fmt.Errorf("secured connection gated") + } + + return newConn(t, sess, sconn, scope), nil +} + +func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { + url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) + var tlsConf *tls.Config + if t.tlsClientConf != nil { + tlsConf = t.tlsClientConf.Clone() + } else { + tlsConf = &tls.Config{} + } + + if len(certHashes) > 0 { + tlsConf.InsecureSkipVerify = true // this is not insecure. We verify the certificate ourselves. + tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyRawCerts(rawCerts, certHashes) + } + } + dialer := webtransport.Dialer{ + RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf}, + } + rsp, sess, err := dialer.Dial(ctx, url, nil) + if err != nil { + return nil, err + } + if rsp.StatusCode < 200 || rsp.StatusCode > 299 { + return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) + } + return sess, err +} + +func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (connSecurityMultiaddrs, error) { + local, err := toWebtransportMultiaddr(sess.LocalAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting local addr: %w", err) + } + remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("error determiniting remote addr: %w", err) + } + + str, err := sess.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + + // Now run a Noise handshake (using early data) and verify the cert hash. + msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))} + for _, certHash := range certHashes { + h, err := multihash.Encode(certHash.Digest, certHash.Code) + if err != nil { + return nil, fmt.Errorf("failed to encode certificate hash: %w", err) + } + msg.CertHashes = append(msg.CertHashes, h) + } + msgBytes, err := msg.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) + } + n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes))) + if err != nil { + return nil, fmt.Errorf("failed to create Noise transport: %w", err) + } + c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p) + if err != nil { + return nil, err + } + return &connSecurityMultiaddrsImpl{ + ConnSecurity: c, + ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, + }, nil +} + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + var numHashes int + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + numHashes++ + } + return true + }) + if numHashes == 0 { + return false + } + for i := 0; i < numHashes; i++ { + addr, _ = ma.SplitLast(addr) + } + return webtransportMatcher.Matches(addr) +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + if !webtransportMatcher.Matches(laddr) { + return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) + } + if t.staticTLSConf == nil { + t.listenOnce.Do(func() { + t.certManager, t.listenOnceErr = newCertManager(t.clock) + }) + if t.listenOnceErr != nil { + return nil, t.listenOnceErr + } + } + return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr) +} + +func (t *transport) Protocols() []int { + return []int{ma.P_WEBTRANSPORT} +} + +func (t *transport) Proxy() bool { + return false +} + +func (t *transport) Close() error { + t.listenOnce.Do(func() {}) + if t.certManager != nil { + return t.certManager.Close() + } + return nil +} diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go new file mode 100644 index 0000000000..02f065bc8c --- /dev/null +++ b/p2p/transport/webtransport/transport_test.go @@ -0,0 +1,512 @@ +package libp2pwebtransport_test + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "fmt" + "io" + "math/big" + "net" + "strings" + "testing" + "time" + + libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" + "github.com/libp2p/go-libp2p/core/peer" + + "github.com/golang/mock/gomock" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multibase" + "github.com/multiformats/go-multihash" + + "github.com/stretchr/testify/require" +) + +func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) { + key, _, err := ic.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(key) + require.NoError(t, err) + return id, key +} + +func randomMultihash(t *testing.T) string { + t.Helper() + b := make([]byte, 16) + rand.Read(b) + h, err := multihash.Encode(b, multihash.KECCAK_224) + require.NoError(t, err) + s, err := multibase.Encode(multibase.Base32hex, h) + require.NoError(t, err) + return s +} + +func extractCertHashes(addr ma.Multiaddr) []string { + var certHashesStr []string + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + certHashesStr = append(certHashesStr, c.Value()) + } + return true + }) + return certHashesStr +} + +func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr { + for { + _, err := addr.ValueForProtocol(ma.P_CERTHASH) + if err != nil { + return addr + } + addr, _ = ma.SplitLast(addr) + } +} + +// create a /certhash multiaddr component using the SHA256 of foobar +func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr { + t.Helper() + h := sha256.Sum256(b) + mh, err := multihash.Encode(h[:], multihash.SHA2_256) + require.NoError(t, err) + certStr, err := multibase.Encode(multibase.Base58BTC, mh) + require.NoError(t, err) + ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) + require.NoError(t, err) + return ha +} + +func TestTransport(t *testing.T) { + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + addrChan := make(chan ma.Multiaddr) + go func() { + _, clientKey := newIdentity(t) + tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr2.(io.Closer).Close() + + conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + _, err = str.Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, str.Close()) + + // check RemoteMultiaddr + _, addr, err := manet.DialArgs(ln.Multiaddr()) + require.NoError(t, err) + _, port, err := net.SplitHostPort(addr) + require.NoError(t, err) + require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic/webtransport", port)), conn.RemoteMultiaddr()) + addrChan <- conn.RemoteMultiaddr() + }() + + conn, err := ln.Accept() + require.NoError(t, err) + require.False(t, conn.IsClosed()) + str, err := conn.AcceptStream() + require.NoError(t, err) + data, err := io.ReadAll(str) + require.NoError(t, err) + require.Equal(t, "foobar", string(data)) + require.Equal(t, <-addrChan, conn.LocalMultiaddr()) + require.NoError(t, conn.Close()) + require.True(t, conn.IsClosed()) +} + +func TestHashVerification(t *testing.T) { + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + done := make(chan struct{}) + go func() { + defer close(done) + _, err := ln.Accept() + require.Error(t, err) + }() + + _, clientKey := newIdentity(t) + tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr2.(io.Closer).Close() + + foobarHash := getCerthashComponent(t, []byte("foobar")) + + t.Run("fails using only a wrong hash", func(t *testing.T) { + // replace the certificate hash in the multiaddr with a fake hash + addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash) + _, err := tr2.Dial(context.Background(), addr, serverID) + require.Error(t, err) + require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found") + }) + + t.Run("fails when adding a wrong hash", func(t *testing.T) { + _, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) + require.Error(t, err) + }) + + require.NoError(t, ln.Close()) + <-done +} + +func TestCanDial(t *testing.T) { + valid := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), + ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), + ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s/certhash/%s", randomMultihash(t), randomMultihash(t), randomMultihash(t))), + } + + invalid := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport + ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic + ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke? + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport"), // missing certificate hash + } + + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + + for _, addr := range valid { + require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr) + } + for _, addr := range invalid { + require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr) + } +} + +func TestListenAddrValidity(t *testing.T) { + valid := []ma.Multiaddr{ + ma.StringCast("/ip6/::/udp/0/quic/webtransport/"), + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/"), + } + + invalid := []ma.Multiaddr{ + ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport + ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic + ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke? + ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), + } + + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + + for _, addr := range valid { + ln, err := tr.Listen(addr) + require.NoErrorf(t, err, "expected to be able to listen on %s", addr) + ln.Close() + } + for _, addr := range invalid { + _, err := tr.Listen(addr) + require.Errorf(t, err, "expected to not be able to listen on %s", addr) + } +} + +func TestListenerAddrs(t *testing.T) { + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + + ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + hashes1 := extractCertHashes(ln1.Multiaddr()) + require.Len(t, hashes1, 2) + hashes2 := extractCertHashes(ln2.Multiaddr()) + require.Equal(t, hashes1, hashes2) +} + +func TestResourceManagerDialing(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + + addr := ma.StringCast("/ip4/9.8.7.6/udp/1234/quic/webtransport") + p := peer.ID("foobar") + + _, key := newIdentity(t) + tr, err := libp2pwebtransport.New(key, nil, rcmgr) + require.NoError(t, err) + defer tr.(io.Closer).Close() + + scope := mocknetwork.NewMockConnManagementScope(ctrl) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil) + scope.EXPECT().SetPeer(p).Return(errors.New("denied")) + scope.EXPECT().Done() + + _, err = tr.Dial(context.Background(), addr, p) + require.EqualError(t, err, "denied") +} + +func TestResourceManagerListening(t *testing.T) { + clientID, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + t.Run("blocking the connection", func(t *testing.T) { + serverID, key := newIdentity(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + tr, err := libp2pwebtransport.New(key, nil, rcmgr) + require.NoError(t, err) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(_ network.Direction, _ bool, addr ma.Multiaddr) (network.ConnManagementScope, error) { + _, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT) + require.NoError(t, err, "expected a WebTransport multiaddr") + _, addrStr, err := manet.DialArgs(addr) + require.NoError(t, err) + host, _, err := net.SplitHostPort(addrStr) + require.NoError(t, err) + require.Equal(t, "127.0.0.1", host) + return nil, errors.New("denied") + }) + + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.EqualError(t, err, "received status 503") + }) + + t.Run("blocking the peer", func(t *testing.T) { + serverID, key := newIdentity(t) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + tr, err := libp2pwebtransport.New(key, nil, rcmgr) + require.NoError(t, err) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + serverDone := make(chan struct{}) + scope := mocknetwork.NewMockConnManagementScope(ctrl) + rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil) + scope.EXPECT().SetPeer(clientID).Return(errors.New("denied")) + scope.EXPECT().Done().Do(func() { close(serverDone) }) + + // The handshake will complete, but the server will immediately close the connection. + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + clientDone := make(chan struct{}) + go func() { + defer close(clientDone) + _, err = conn.AcceptStream() + require.Error(t, err) + }() + select { + case <-clientDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + select { + case <-serverDone: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + }) +} + +// TODO: unify somehow. We do the same in libp2pquic. +//go:generate sh -c "mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" + +func TestConnectionGaterDialing(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) + }) + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.EqualError(t, err, "secured connection gated") +} + +func TestConnectionGaterInterceptAccept(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) + require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) + }) + + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.EqualError(t, err, "received status 403") +} + +func TestConnectionGaterInterceptSecured(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + + clientID, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) + connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) + require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) + }) + // The handshake will complete, but the server will immediately close the connection. + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + done := make(chan struct{}) + go func() { + defer close(done) + _, err = conn.AcceptStream() + require.Error(t, err) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} + +func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { + t.Helper() + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1234), + Subject: pkix.Name{Organization: []string{"webtransport"}}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, + } + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv) + require.NoError(t, err) + cert, err := x509.ParseCertificate(caBytes) + require.NoError(t, err) + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: priv, + Leaf: cert, + }}, + } +} + +func TestStaticTLSConf(t *testing.T) { + tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour)) + + serverID, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf)) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + defer ln.Close() + require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash") + + t.Run("fails when the certificate is invalid", func(t *testing.T) { + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.Error(t, err) + if !strings.Contains(err.Error(), "certificate is not trusted") && + !strings.Contains(err.Error(), "certificate signed by unknown authority") { + t.Fatalf("expected a certificate error, got %+v", err) + } + }) + + t.Run("fails when dialing with a wrong certhash", func(t *testing.T) { + _, key := newIdentity(t) + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo"))) + _, err = cl.Dial(context.Background(), addr, serverID) + require.Error(t, err) + require.Contains(t, err.Error(), "cert hash not found") + }) + + t.Run("accepts a valid TLS certificate", func(t *testing.T) { + _, key := newIdentity(t) + store := x509.NewCertPool() + store.AddCert(tlsConf.Certificates[0].Leaf) + tlsConf := &tls.Config{RootCAs: store} + cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf)) + require.NoError(t, err) + defer cl.(io.Closer).Close() + + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + }) +}