Skip to content

Commit

Permalink
merge go-libp2p-webtransport here
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 3, 2022
2 parents 366cc3d + 97e739f commit fab7475
Show file tree
Hide file tree
Showing 16 changed files with 2,304 additions and 0 deletions.
181 changes: 181 additions & 0 deletions p2p/transport/webtransport/cert_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package libp2pwebtransport

import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"fmt"
"sync"
"time"

"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
)

// Allow for a bit of clock skew.
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour

type certConfig struct {
tlsConf *tls.Config
sha256 [32]byte // cached from the tlsConf
}

func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }

func newCertConfig(start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(start, end)
if err != nil {
return nil, err
}
return &certConfig{
tlsConf: conf,
sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw),
}, nil
}

// Certificate renewal logic:
// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another
// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew).
// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate.
// At the same time, we stop advertising the certhash of the first cert and generate the next cert.
type certManager struct {
clock clock.Clock
ctx context.Context
ctxCancel context.CancelFunc
refCount sync.WaitGroup

mx sync.RWMutex
lastConfig *certConfig // initially nil
currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
addrComp ma.Multiaddr
}

func newCertManager(clock clock.Clock) (*certManager, error) {
m := &certManager{clock: clock}
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
if err := m.init(); err != nil {
return nil, err
}

m.background()
return m, nil
}

func (m *certManager) init() error {
start := m.clock.Now().Add(-clockSkewAllowance)
var err error
m.nextConfig, err = newCertConfig(start, start.Add(certValidity))
if err != nil {
return err
}
return m.rollConfig()
}

func (m *certManager) rollConfig() error {
// We stop using the current certificate clockSkewAllowance before its expiry time.
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
return m.cacheAddrComponent()
}

func (m *certManager) background() {
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
log.Debugw("setting timer", "duration", d.String())
t := m.clock.Timer(d)
m.refCount.Add(1)

go func() {
defer m.refCount.Done()
defer t.Stop()

for {
select {
case <-m.ctx.Done():
return
case now := <-t.C:
m.mx.Lock()
if err := m.rollConfig(); err != nil {
log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
log.Debugw("rolling certificates", "next", d.String())
t.Reset(d)
m.mx.Unlock()
}
}
}()
}

func (m *certManager) GetConfig() *tls.Config {
m.mx.RLock()
defer m.mx.RUnlock()
return m.currentConfig.tlsConf
}

func (m *certManager) AddrComponent() ma.Multiaddr {
m.mx.RLock()
defer m.mx.RUnlock()
return m.addrComp
}

func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error {
for _, h := range hashes {
if h.Code != multihash.SHA2_256 {
return fmt.Errorf("expected SHA256 hash, got %d", h.Code)
}
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) &&
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) &&
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) {
return fmt.Errorf("found unexpected hash: %+x", h.Digest)
}
}
return nil
}

func (m *certManager) cacheAddrComponent() error {
addr, err := m.addrComponentForCert(m.currentConfig.sha256[:])
if err != nil {
return err
}
if m.nextConfig != nil {
comp, err := m.addrComponentForCert(m.nextConfig.sha256[:])
if err != nil {
return err
}
addr = addr.Encapsulate(comp)
}
m.addrComp = addr
return nil
}

func (m *certManager) addrComponentForCert(hash []byte) (ma.Multiaddr, error) {
mh, err := multihash.Encode(hash, multihash.SHA2_256)
if err != nil {
return nil, err
}
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
if err != nil {
return nil, err
}
return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
}

func (m *certManager) Close() error {
m.ctxCancel()
m.refCount.Wait()
return nil
}
102 changes: 102 additions & 0 deletions p2p/transport/webtransport/cert_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package libp2pwebtransport

import (
"crypto/sha256"
"crypto/tls"
"testing"
"time"

"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
)

func certificateHashFromTLSConfig(c *tls.Config) [32]byte {
return sha256.Sum256(c.Certificates[0].Certificate[0])
}

func splitMultiaddr(addr ma.Multiaddr) []ma.Component {
var components []ma.Component
ma.ForEach(addr, func(c ma.Component) bool {
components = append(components, c)
return true
})
return components
}

func certHashFromComponent(t *testing.T, comp ma.Component) []byte {
t.Helper()
_, data, err := multibase.Decode(comp.Value())
require.NoError(t, err)
mh, err := multihash.Decode(data)
require.NoError(t, err)
require.Equal(t, uint64(multihash.SHA2_256), mh.Code)
return mh.Digest
}

func TestInitialCert(t *testing.T) {
cl := clock.NewMock()
cl.Add(1234567 * time.Hour)
m, err := newCertManager(cl)
require.NoError(t, err)
defer m.Close()

conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0]
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore)
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter)
addr := m.AddrComponent()
components := splitMultiaddr(addr)
require.Len(t, components, 2)
require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code)
hash := certificateHashFromTLSConfig(conf)
require.Equal(t, hash[:], certHashFromComponent(t, components[0]))
require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code)
}

func TestCertRenewal(t *testing.T) {
cl := clock.NewMock()
m, err := newCertManager(cl)
require.NoError(t, err)
defer m.Close()

firstConf := m.GetConfig()
first := splitMultiaddr(m.AddrComponent())
require.Len(t, first, 2)
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ")
// wait for a new certificate to be generated
cl.Add(certValidity - 2*clockSkewAllowance - time.Second)
require.Never(t, func() bool {
for i, c := range splitMultiaddr(m.AddrComponent()) {
if c.Value() != first[i].Value() {
return true
}
}
return false
}, 100*time.Millisecond, 10*time.Millisecond)
cl.Add(2 * time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond)
secondConf := m.GetConfig()

second := splitMultiaddr(m.AddrComponent())
require.Len(t, second, 2)
for _, c := range second {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, first[1].Value(), second[0].Value())
require.NotEqual(t, first[0].Value(), second[1].Value())

cl.Add(certValidity - 2*clockSkewAllowance + time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond)
third := splitMultiaddr(m.AddrComponent())
require.Len(t, third, 2)
for _, c := range third {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, second[1].Value(), third[0].Value())
}
65 changes: 65 additions & 0 deletions p2p/transport/webtransport/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package libp2pwebtransport

import (
"context"

"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"

"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
)

type connSecurityMultiaddrs interface {
network.ConnMultiaddrs
network.ConnSecurity
}

type connSecurityMultiaddrsImpl struct {
network.ConnSecurity
network.ConnMultiaddrs
}

type connMultiaddrs struct {
local, remote ma.Multiaddr
}

var _ network.ConnMultiaddrs = &connMultiaddrs{}

func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote }

type conn struct {
connSecurityMultiaddrs

transport tpt.Transport
session *webtransport.Session

scope network.ConnScope
}

var _ tpt.CapableConn = &conn{}

func newConn(tr tpt.Transport, sess *webtransport.Session, sconn connSecurityMultiaddrs, scope network.ConnScope) *conn {
return &conn{
connSecurityMultiaddrs: sconn,
transport: tr,
session: sess,
scope: scope,
}
}

func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
str, err := c.session.OpenStreamSync(ctx)
return &stream{str}, err
}

func (c *conn) AcceptStream() (network.MuxedStream, error) {
str, err := c.session.AcceptStream(context.Background())
return &stream{str}, err
}

func (c *conn) Close() error { return c.session.Close() }
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }
Loading

0 comments on commit fab7475

Please sign in to comment.