-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
2,304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package libp2pwebtransport | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/sha256" | ||
"crypto/tls" | ||
"fmt" | ||
"sync" | ||
"time" | ||
|
||
"github.com/benbjohnson/clock" | ||
ma "github.com/multiformats/go-multiaddr" | ||
"github.com/multiformats/go-multibase" | ||
"github.com/multiformats/go-multihash" | ||
) | ||
|
||
// Allow for a bit of clock skew. | ||
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time. | ||
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time. | ||
const clockSkewAllowance = time.Hour | ||
|
||
type certConfig struct { | ||
tlsConf *tls.Config | ||
sha256 [32]byte // cached from the tlsConf | ||
} | ||
|
||
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore } | ||
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter } | ||
|
||
func newCertConfig(start, end time.Time) (*certConfig, error) { | ||
conf, err := getTLSConf(start, end) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &certConfig{ | ||
tlsConf: conf, | ||
sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw), | ||
}, nil | ||
} | ||
|
||
// Certificate renewal logic: | ||
// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another | ||
// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew). | ||
// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate. | ||
// At the same time, we stop advertising the certhash of the first cert and generate the next cert. | ||
type certManager struct { | ||
clock clock.Clock | ||
ctx context.Context | ||
ctxCancel context.CancelFunc | ||
refCount sync.WaitGroup | ||
|
||
mx sync.RWMutex | ||
lastConfig *certConfig // initially nil | ||
currentConfig *certConfig | ||
nextConfig *certConfig // nil until we have passed half the certValidity of the current config | ||
addrComp ma.Multiaddr | ||
} | ||
|
||
func newCertManager(clock clock.Clock) (*certManager, error) { | ||
m := &certManager{clock: clock} | ||
m.ctx, m.ctxCancel = context.WithCancel(context.Background()) | ||
if err := m.init(); err != nil { | ||
return nil, err | ||
} | ||
|
||
m.background() | ||
return m, nil | ||
} | ||
|
||
func (m *certManager) init() error { | ||
start := m.clock.Now().Add(-clockSkewAllowance) | ||
var err error | ||
m.nextConfig, err = newCertConfig(start, start.Add(certValidity)) | ||
if err != nil { | ||
return err | ||
} | ||
return m.rollConfig() | ||
} | ||
|
||
func (m *certManager) rollConfig() error { | ||
// We stop using the current certificate clockSkewAllowance before its expiry time. | ||
// At this point, the next certificate needs to be valid for one clockSkewAllowance. | ||
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance) | ||
c, err := newCertConfig(nextStart, nextStart.Add(certValidity)) | ||
if err != nil { | ||
return err | ||
} | ||
m.lastConfig = m.currentConfig | ||
m.currentConfig = m.nextConfig | ||
m.nextConfig = c | ||
return m.cacheAddrComponent() | ||
} | ||
|
||
func (m *certManager) background() { | ||
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now()) | ||
log.Debugw("setting timer", "duration", d.String()) | ||
t := m.clock.Timer(d) | ||
m.refCount.Add(1) | ||
|
||
go func() { | ||
defer m.refCount.Done() | ||
defer t.Stop() | ||
|
||
for { | ||
select { | ||
case <-m.ctx.Done(): | ||
return | ||
case now := <-t.C: | ||
m.mx.Lock() | ||
if err := m.rollConfig(); err != nil { | ||
log.Errorw("rolling config failed", "error", err) | ||
} | ||
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now) | ||
log.Debugw("rolling certificates", "next", d.String()) | ||
t.Reset(d) | ||
m.mx.Unlock() | ||
} | ||
} | ||
}() | ||
} | ||
|
||
func (m *certManager) GetConfig() *tls.Config { | ||
m.mx.RLock() | ||
defer m.mx.RUnlock() | ||
return m.currentConfig.tlsConf | ||
} | ||
|
||
func (m *certManager) AddrComponent() ma.Multiaddr { | ||
m.mx.RLock() | ||
defer m.mx.RUnlock() | ||
return m.addrComp | ||
} | ||
|
||
func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error { | ||
for _, h := range hashes { | ||
if h.Code != multihash.SHA2_256 { | ||
return fmt.Errorf("expected SHA256 hash, got %d", h.Code) | ||
} | ||
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) && | ||
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) && | ||
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) { | ||
return fmt.Errorf("found unexpected hash: %+x", h.Digest) | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (m *certManager) cacheAddrComponent() error { | ||
addr, err := m.addrComponentForCert(m.currentConfig.sha256[:]) | ||
if err != nil { | ||
return err | ||
} | ||
if m.nextConfig != nil { | ||
comp, err := m.addrComponentForCert(m.nextConfig.sha256[:]) | ||
if err != nil { | ||
return err | ||
} | ||
addr = addr.Encapsulate(comp) | ||
} | ||
m.addrComp = addr | ||
return nil | ||
} | ||
|
||
func (m *certManager) addrComponentForCert(hash []byte) (ma.Multiaddr, error) { | ||
mh, err := multihash.Encode(hash, multihash.SHA2_256) | ||
if err != nil { | ||
return nil, err | ||
} | ||
certStr, err := multibase.Encode(multibase.Base58BTC, mh) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) | ||
} | ||
|
||
func (m *certManager) Close() error { | ||
m.ctxCancel() | ||
m.refCount.Wait() | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package libp2pwebtransport | ||
|
||
import ( | ||
"crypto/sha256" | ||
"crypto/tls" | ||
"testing" | ||
"time" | ||
|
||
"github.com/benbjohnson/clock" | ||
ma "github.com/multiformats/go-multiaddr" | ||
"github.com/multiformats/go-multibase" | ||
"github.com/multiformats/go-multihash" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func certificateHashFromTLSConfig(c *tls.Config) [32]byte { | ||
return sha256.Sum256(c.Certificates[0].Certificate[0]) | ||
} | ||
|
||
func splitMultiaddr(addr ma.Multiaddr) []ma.Component { | ||
var components []ma.Component | ||
ma.ForEach(addr, func(c ma.Component) bool { | ||
components = append(components, c) | ||
return true | ||
}) | ||
return components | ||
} | ||
|
||
func certHashFromComponent(t *testing.T, comp ma.Component) []byte { | ||
t.Helper() | ||
_, data, err := multibase.Decode(comp.Value()) | ||
require.NoError(t, err) | ||
mh, err := multihash.Decode(data) | ||
require.NoError(t, err) | ||
require.Equal(t, uint64(multihash.SHA2_256), mh.Code) | ||
return mh.Digest | ||
} | ||
|
||
func TestInitialCert(t *testing.T) { | ||
cl := clock.NewMock() | ||
cl.Add(1234567 * time.Hour) | ||
m, err := newCertManager(cl) | ||
require.NoError(t, err) | ||
defer m.Close() | ||
|
||
conf := m.GetConfig() | ||
require.Len(t, conf.Certificates, 1) | ||
cert := conf.Certificates[0] | ||
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore) | ||
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter) | ||
addr := m.AddrComponent() | ||
components := splitMultiaddr(addr) | ||
require.Len(t, components, 2) | ||
require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) | ||
hash := certificateHashFromTLSConfig(conf) | ||
require.Equal(t, hash[:], certHashFromComponent(t, components[0])) | ||
require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code) | ||
} | ||
|
||
func TestCertRenewal(t *testing.T) { | ||
cl := clock.NewMock() | ||
m, err := newCertManager(cl) | ||
require.NoError(t, err) | ||
defer m.Close() | ||
|
||
firstConf := m.GetConfig() | ||
first := splitMultiaddr(m.AddrComponent()) | ||
require.Len(t, first, 2) | ||
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ") | ||
// wait for a new certificate to be generated | ||
cl.Add(certValidity - 2*clockSkewAllowance - time.Second) | ||
require.Never(t, func() bool { | ||
for i, c := range splitMultiaddr(m.AddrComponent()) { | ||
if c.Value() != first[i].Value() { | ||
return true | ||
} | ||
} | ||
return false | ||
}, 100*time.Millisecond, 10*time.Millisecond) | ||
cl.Add(2 * time.Second) | ||
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) | ||
secondConf := m.GetConfig() | ||
|
||
second := splitMultiaddr(m.AddrComponent()) | ||
require.Len(t, second, 2) | ||
for _, c := range second { | ||
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) | ||
} | ||
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate | ||
require.Equal(t, first[1].Value(), second[0].Value()) | ||
require.NotEqual(t, first[0].Value(), second[1].Value()) | ||
|
||
cl.Add(certValidity - 2*clockSkewAllowance + time.Second) | ||
require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond) | ||
third := splitMultiaddr(m.AddrComponent()) | ||
require.Len(t, third, 2) | ||
for _, c := range third { | ||
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) | ||
} | ||
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate | ||
require.Equal(t, second[1].Value(), third[0].Value()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package libp2pwebtransport | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/libp2p/go-libp2p/core/network" | ||
tpt "github.com/libp2p/go-libp2p/core/transport" | ||
|
||
"github.com/marten-seemann/webtransport-go" | ||
ma "github.com/multiformats/go-multiaddr" | ||
) | ||
|
||
type connSecurityMultiaddrs interface { | ||
network.ConnMultiaddrs | ||
network.ConnSecurity | ||
} | ||
|
||
type connSecurityMultiaddrsImpl struct { | ||
network.ConnSecurity | ||
network.ConnMultiaddrs | ||
} | ||
|
||
type connMultiaddrs struct { | ||
local, remote ma.Multiaddr | ||
} | ||
|
||
var _ network.ConnMultiaddrs = &connMultiaddrs{} | ||
|
||
func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } | ||
func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } | ||
|
||
type conn struct { | ||
connSecurityMultiaddrs | ||
|
||
transport tpt.Transport | ||
session *webtransport.Session | ||
|
||
scope network.ConnScope | ||
} | ||
|
||
var _ tpt.CapableConn = &conn{} | ||
|
||
func newConn(tr tpt.Transport, sess *webtransport.Session, sconn connSecurityMultiaddrs, scope network.ConnScope) *conn { | ||
return &conn{ | ||
connSecurityMultiaddrs: sconn, | ||
transport: tr, | ||
session: sess, | ||
scope: scope, | ||
} | ||
} | ||
|
||
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { | ||
str, err := c.session.OpenStreamSync(ctx) | ||
return &stream{str}, err | ||
} | ||
|
||
func (c *conn) AcceptStream() (network.MuxedStream, error) { | ||
str, err := c.session.AcceptStream(context.Background()) | ||
return &stream{str}, err | ||
} | ||
|
||
func (c *conn) Close() error { return c.session.Close() } | ||
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } | ||
func (c *conn) Scope() network.ConnScope { return c.scope } | ||
func (c *conn) Transport() tpt.Transport { return c.transport } |
Oops, something went wrong.