Skip to content

Commit

Permalink
Skip one byte reads and writes to be deterministic
Browse files Browse the repository at this point in the history
TLS handshake key generation uses internally `randutil.MaybeReadByte`
which randomly reads one byte from the Rand reader. This makes the
bytes used by handshake process non-deterministic.

To avoid this, we override the state Rand buffer to avoid writing or
reading from it when only one byte is needed.

See https://github.com/golang/go/blob/70491a81113e7003e314451f3e3cf134c4d41dd7/src/crypto/internal/randutil/randutil.go#L25
  • Loading branch information
igolaizola committed Jul 18, 2024
1 parent d1d1515 commit a5ae86d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
28 changes: 28 additions & 0 deletions internal/io/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,31 @@ func (r *OverrideReader) Read(p []byte) (n int, err error) {
}
return r.Reader.Read(p)
}

// SkipOneReader is an io.Reader implementation that skips reading when only
// one byte is requested.
type SkipOneReader struct {
io.Reader
}

// Read implements io.Reader.Read
func (r *SkipOneReader) Read(p []byte) (n int, err error) {
if len(p) == 1 {
return 1, nil
}
return r.Reader.Read(p)
}

// SkipOneWriter is an io.Writer implementation that skips writing when only
// one byte is written.
type SkipOneWriter struct {
io.Writer
}

// Write implements io.Writer.Write
func (r *SkipOneWriter) Write(p []byte) (n int, err error) {
if len(p) == 1 {
return 1, nil
}
return r.Writer.Write(p)
}
38 changes: 21 additions & 17 deletions resumetls.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,7 @@ func Server(conn net.Conn, cfg *tls.Config, state *State) (*Conn, error) {
// newConn returns a resumable tls conn
func newConn(tlsConn func(net.Conn, *tls.Config) *tls.Conn, conn net.Conn, cfg *tls.Config, state *State) (*Conn, error) {
if state != nil {
var err error
// Client hello message sometimes consumes more random bytes than the
// ones provided by the state. Probably due to how elliptic curves keys
// are generated.
// We have tested empirically that retrying 10 times is enough to get a
// successful handshake. Here, we set the limit to 20 to be on the safe
// side.
for i := 0; i < 20; i++ {
var c *Conn
c, err = resume(tlsConn, conn, cfg, state)
if err == nil {
return c, err
}
}
return nil, err
return resume(tlsConn, conn, cfg, state)
}
return initialize(tlsConn, conn, cfg), nil
}
Expand All @@ -72,8 +58,17 @@ func initialize(tlsConn func(net.Conn, *tls.Config) *tls.Conn, conn net.Conn, cf
if rnd == nil {
rnd = rand.Reader
}

// TLS handshake key generation uses internally `randutil.MaybeReadByte`
// which randomly reads one byte from the Rand reader. This makes the
// bytes used by handshake process non-deterministic. To avoid this, we
// override the Rand reader to avoid storing in the buffer when only one
// byte is written.
// See https://github.com/golang/go/blob/70491a81113e7003e314451f3e3cf134c4d41dd7/src/crypto/internal/randutil/randutil.go#L25
randWriter := &intio.SkipOneWriter{Writer: randBuf}

ovRand := &intio.OverrideReader{
OverrideReader: io.TeeReader(rnd, randBuf),
OverrideReader: io.TeeReader(rnd, randWriter),
Reader: rnd,
}
ovConn := &intnet.OverrideConn{
Expand All @@ -97,8 +92,17 @@ func resume(tlsConn func(net.Conn, *tls.Config) *tls.Conn, conn net.Conn, cfg *t
if rnd == nil {
rnd = rand.Reader
}

// TLS handshake key generation uses internally `randutil.MaybeReadByte`
// which randomly reads one byte from the Rand reader. This makes the
// bytes used by handshake process non-deterministic. To avoid this, we
// override the state Rand reader to avoid reading from the buffer when only
// one byte is read.
// See https://github.com/golang/go/blob/70491a81113e7003e314451f3e3cf134c4d41dd7/src/crypto/internal/randutil/randutil.go#L25
stateRandReader := &intio.SkipOneReader{Reader: bytes.NewBuffer(state.rand)}

ovRand := &intio.OverrideReader{
OverrideReader: io.MultiReader(bytes.NewBuffer(state.rand), rnd),
OverrideReader: io.MultiReader(stateRandReader, rnd),
Reader: rnd,
}
ovConn := &intnet.OverrideConn{
Expand Down

0 comments on commit a5ae86d

Please sign in to comment.