Skip to content

Commit

Permalink
feat: context-managed streams (#179)
Browse files Browse the repository at this point in the history
* Modifying stream management within the context of a server or client to properly make use of contexts.

Signed-off-by: Shivansh Vij <shivanshvij@loopholelabs.io>

* Renaming client.ctx to client.baseContext

Signed-off-by: Shivansh Vij <shivanshvij@loopholelabs.io>

* Breaking: Changing `NewServer` to accept a context the same way the client does.

Signed-off-by: Shivansh Vij <shivanshvij@loopholelabs.io>

---------

Signed-off-by: Shivansh Vij <shivanshvij@loopholelabs.io>
  • Loading branch information
ShivanshVij authored Sep 8, 2024
1 parent a83a18c commit 6e1bae7
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 75 deletions.
8 changes: 6 additions & 2 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,17 @@ func NewAsync(c net.Conn, logger types.Logger, streamHandler ...NewStreamHandler
conn.logger = noop.New(types.InfoLevel)
}

if len(streamHandler) > 0 {
if len(streamHandler) > 0 && streamHandler[0] != nil {
conn.newStreamHandler = streamHandler[0]
}

conn.wg.Add(3)
conn.wg.Add(1)
go conn.flushLoop()

conn.wg.Add(1)
go conn.readLoop()

conn.wg.Add(1)
go conn.pingLoop()

return
Expand Down
39 changes: 29 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,25 @@ import (
type Client struct {
conn *Async
handlerTable HandlerTable
ctx context.Context
options *Options
closed atomic.Bool
wg sync.WaitGroup
heartbeatChannel chan struct{}

baseContext context.Context
baseContextCancel context.CancelFunc

// PacketContext is used to define packet-specific contexts based on the incoming packet
// and is run whenever a new packet arrives
PacketContext func(context.Context, *packet.Packet) context.Context

// UpdateContext is used to update a handler-specific context whenever the returned
// Action from a handler is UPDATE
UpdateContext func(context.Context, *Async) context.Context

// StreamContext is used to update a handler-specific context whenever a new stream is created
// and is run whenever a new stream is created
StreamContext func(context.Context, *Stream) context.Context
}

// NewClient returns an uninitialized frisbee Client with the registered ClientRouter.
Expand All @@ -44,11 +50,14 @@ func NewClient(handlerTable HandlerTable, ctx context.Context, opts ...Option) (
options := loadOptions(opts...)
var heartbeatChannel chan struct{}

baseContext, baseContextCancel := context.WithCancel(ctx)

return &Client{
handlerTable: handlerTable,
ctx: ctx,
options: options,
heartbeatChannel: heartbeatChannel,
handlerTable: handlerTable,
baseContext: baseContext,
baseContextCancel: baseContextCancel,
options: options,
heartbeatChannel: heartbeatChannel,
}, nil
}

Expand Down Expand Up @@ -94,6 +103,7 @@ func (c *Client) Error() error {
// Close closes the frisbee client and kills all the goroutines
func (c *Client) Close() error {
if c.closed.CompareAndSwap(false, true) {
c.baseContextCancel()
err := c.conn.Close()
if err != nil {
return err
Expand Down Expand Up @@ -138,15 +148,24 @@ func (c *Client) Stream(id uint16) *Stream {
return c.conn.NewStream(id)
}

// SetNewStreamHandler sets the callback handler for new streams.
// SetStreamHandler sets the callback handler for new streams.
//
// It's important to note that this handler is called for new streams and if it is
// not set then stream packets will be dropped.
//
// It's also important to note that the handler itself is called in its own goroutine to
// avoid blocking the read lop. This means that the handler must be thread-safe.
func (c *Client) SetNewStreamHandler(handler NewStreamHandler) {
c.conn.SetNewStreamHandler(handler)
// avoid blocking the read loop. This means that the handler must be thread-safe.
func (c *Client) SetStreamHandler(f func(context.Context, *Stream)) {
if f == nil {
c.conn.SetNewStreamHandler(nil)
}
c.conn.SetNewStreamHandler(func(s *Stream) {
streamCtx := c.baseContext
if c.StreamContext != nil {
streamCtx = c.StreamContext(streamCtx, s)
}
f(streamCtx, s)
})
}

// Logger returns the client's logger (useful for ClientRouter functions)
Expand Down Expand Up @@ -174,7 +193,7 @@ func (c *Client) handleConn() {
}
handlerFunc = c.handlerTable[p.Metadata.Operation]
if handlerFunc != nil {
packetCtx := c.ctx
packetCtx := c.baseContext
if c.PacketContext != nil {
packetCtx = c.PacketContext(packetCtx, p)
}
Expand Down
8 changes: 4 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestClientRaw(t *testing.T) {
}

emptyLogger := logging.Test(t, logging.Noop, t.Name())
s, err := NewServer(serverHandlerTable, WithLogger(emptyLogger))
s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger))
require.NoError(t, err)

s.SetConcurrency(1)
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestClientStaleClose(t *testing.T) {
}

emptyLogger := logging.Test(t, logging.Noop, t.Name())
s, err := NewServer(serverHandlerTable, WithLogger(emptyLogger))
s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger))
require.NoError(t, err)

s.SetConcurrency(1)
Expand Down Expand Up @@ -204,7 +204,7 @@ func BenchmarkThroughputClient(b *testing.B) {
}

emptyLogger := logging.Test(b, logging.Noop, b.Name())
s, err := NewServer(serverHandlerTable, WithLogger(emptyLogger))
s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger))
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -289,7 +289,7 @@ func BenchmarkThroughputResponseClient(b *testing.B) {
}

emptyLogger := logging.Test(b, logging.Noop, b.Name())
s, err := NewServer(serverHandlerTable, WithLogger(emptyLogger))
s, err := NewServer(serverHandlerTable, context.Background(), WithLogger(emptyLogger))
if err != nil {
b.Fatal(err)
}
Expand Down
84 changes: 41 additions & 43 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@ import (
)

var (
BaseContextNil = errors.New("BaseContext cannot be nil")
OnClosedNil = errors.New("OnClosed cannot be nil")
PreWriteNil = errors.New("PreWrite cannot be nil")
StreamHandlerNil = errors.New("StreamHandler cannot be nil")
ListenerNil = errors.New("Listener cannot be nil")
OnClosedNil = errors.New("OnClosed function cannot be nil")
PreWriteNil = errors.New("PreWrite function cannot be nil")
ListenerNil = errors.New("listener cannot be nil")
)

var (
defaultBaseContext = context.Background

defaultOnClosed = func(_ *Async, _ error) {}

defaultPreWrite = func() {}
Expand All @@ -48,8 +44,8 @@ type Server struct {
concurrency uint64
limiter chan struct{}

// baseContext is used to define the base context for this Server and all incoming connections
baseContext func() context.Context
baseContext context.Context
baseContextCancel context.CancelFunc

// onClosed is a function run by the server whenever a connection is closed
onClosed func(*Async, error)
Expand All @@ -64,6 +60,10 @@ type Server struct {
// and is run whenever a new connection is opened
ConnContext func(context.Context, *Async) context.Context

// StreamContext is used to define a stream-specific context based on the incoming stream
// and is run whenever a new stream is opened
StreamContext func(context.Context, *Stream) context.Context

// PacketContext is used to define a handler-specific contexts based on the incoming packet
// and is run whenever a new packet arrives
PacketContext func(context.Context, *packet.Packet) context.Context
Expand All @@ -75,30 +75,25 @@ type Server struct {

// NewServer returns an uninitialized frisbee Server with the registered HandlerTable.
// The Start method must then be called to start the server and listen for connections.
func NewServer(handlerTable HandlerTable, opts ...Option) (*Server, error) {
func NewServer(handlerTable HandlerTable, ctx context.Context, opts ...Option) (*Server, error) {
options := loadOptions(opts...)

baseContext, baseContextCancel := context.WithCancel(ctx)

s := &Server{
options: options,
connections: make(map[*Async]struct{}),
startedCh: make(chan struct{}),
baseContext: defaultBaseContext,
onClosed: defaultOnClosed,
preWrite: defaultPreWrite,
streamHandler: defaultStreamHandler,
options: options,
connections: make(map[*Async]struct{}),
startedCh: make(chan struct{}),
baseContext: baseContext,
baseContextCancel: baseContextCancel,
onClosed: defaultOnClosed,
preWrite: defaultPreWrite,
streamHandler: defaultStreamHandler,
}

return s, s.SetHandlerTable(handlerTable)
}

// SetBaseContext sets the baseContext function for the server. If f is nil, it returns an error.
func (s *Server) SetBaseContext(f func() context.Context) error {
if f == nil {
return BaseContextNil
}
s.baseContext = f
return nil
}

// SetOnClosed sets the onClosed function for the server. If f is nil, it returns an error.
func (s *Server) SetOnClosed(f func(*Async, error)) error {
if f == nil {
Expand All @@ -117,13 +112,14 @@ func (s *Server) SetPreWrite(f func()) error {
return nil
}

// SetStreamHandler sets the streamHandler function for the server. If f is nil, it returns an error.
func (s *Server) SetStreamHandler(f func(*Async, *Stream)) error {
if f == nil {
return StreamHandlerNil
}
// SetStreamHandler sets the streamHandler function for the server.
func (s *Server) SetStreamHandler(f func(context.Context, *Stream)) error {
s.streamHandler = func(stream *Stream) {
f(stream.Conn(), stream)
streamCtx := s.baseContext
if s.StreamContext != nil {
streamCtx = s.StreamContext(streamCtx, stream)
}
f(streamCtx, stream)
}
return nil
}
Expand Down Expand Up @@ -433,7 +429,7 @@ func (s *Server) serveConn(newConn net.Conn) {
}

frisbeeConn := NewAsync(newConn, s.Logger(), s.streamHandler)
connCtx := s.baseContext()
connCtx := s.baseContext
s.connectionsMu.Lock()
if s.shutdown.Load() {
s.wg.Done()
Expand Down Expand Up @@ -467,16 +463,18 @@ func (s *Server) Logger() types.Logger {

// Shutdown shuts down the frisbee server and kills all the goroutines and active connections
func (s *Server) Shutdown() error {
s.shutdown.Store(true)
s.connectionsMu.Lock()
for c := range s.connections {
_ = c.Close()
delete(s.connections, c)
}
s.connectionsMu.Unlock()
defer s.wg.Wait()
if s.listener != nil {
return s.listener.Close()
if s.shutdown.CompareAndSwap(false, true) {
s.baseContextCancel()
s.connectionsMu.Lock()
for c := range s.connections {
_ = c.Close()
delete(s.connections, c)
}
s.connectionsMu.Unlock()
defer s.wg.Wait()
if s.listener != nil {
return s.listener.Close()
}
}
return nil
}
Loading

0 comments on commit 6e1bae7

Please sign in to comment.