Skip to content

Commit

Permalink
feat: gracefully serve, shutdown http node
Browse files Browse the repository at this point in the history
  • Loading branch information
siyul-park committed Nov 29, 2023
1 parent a1012ed commit 3747184
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 16 deletions.
22 changes: 21 additions & 1 deletion pkg/plugin/networkx/builder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package networkx

import (
"context"
"time"

"github.com/siyul-park/uniflow/pkg/hook"
"github.com/siyul-park/uniflow/pkg/node"
"github.com/siyul-park/uniflow/pkg/scheme"
Expand All @@ -11,7 +14,24 @@ func AddToHooks() func(*hook.Hook) error {
return func(h *hook.Hook) error {
h.AddLoadHook(symbol.LoadHookFunc(func(n node.Node) error {
if n, ok := n.(*HTTPNode); ok {
go func() { n.Start() }()
errChan := make(chan error)

go func() {
if err := n.Serve(); err != nil {
errChan <- err
}
}()

return n.WaitForListen(errChan)
}
return nil
}))
h.AddUnloadHook(symbol.UnloadHookFunc(func(n node.Node) error {
if n, ok := n.(*HTTPNode); ok {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

return n.Shutdown(ctx)
}
return nil
}))
Expand Down
6 changes: 5 additions & 1 deletion pkg/plugin/networkx/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ func TestAddToHooks(t *testing.T) {
Address: fmt.Sprintf(":%d", port),
})

hk.Load(n)
err = hk.Load(n)
assert.NoError(t, err)

errChan := make(chan error)

err = n.WaitForListen(errChan)

assert.NoError(t, err)
assert.NoError(t, n.Close())

err = hk.Unload(n)
assert.NoError(t, err)
}

func TestAddToScheme(t *testing.T) {
Expand Down
41 changes: 32 additions & 9 deletions pkg/plugin/networkx/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ import (
)

type (
// HTTPNodeConfig represents the configuration of an HTTP node.
HTTPNodeConfig struct {
ID ulid.ULID
Address string
}

// HTTPNode represents a node based on the HTTP protocol.
HTTPNode struct {
id ulid.ULID
address string
Expand All @@ -41,6 +44,13 @@ type (
mu sync.RWMutex
}

// HTTPSpec represents the specification of an HTTP node.
HTTPSpec struct {
scheme.SpecMeta `map:",inline"`
Address string `map:"address"`
}

// HTTPPayload represents the payload for HTTP requests and responses.
HTTPPayload struct {
Proto string `map:"proto,omitempty"`
Path string `map:"path,omitempty"`
Expand All @@ -52,11 +62,6 @@ type (
Status int `map:"status"`
}

HTTPSpec struct {
scheme.SpecMeta `map:",inline"`
Address string `map:"address"`
}

tcpKeepAliveListener struct {
*net.TCPListener
}
Expand All @@ -66,9 +71,7 @@ const (
KindHTTP = "http"
)

var _ node.Node = &HTTPNode{}
var _ http.Handler = &HTTPNode{}

// Commonly used HTTP header constants.
const (
HeaderAccept = "Accept"
HeaderAcceptCharset = "Accept-Charset"
Expand Down Expand Up @@ -158,6 +161,7 @@ const (
HeaderReferrerPolicy = "Referrer-Policy"
)

// HTTP error response payload.
var (
BadRequest = NewHTTPPayload(http.StatusBadRequest) // HTTP 400 Bad Request
Unauthorized = NewHTTPPayload(http.StatusUnauthorized) // HTTP 401 Unauthorized
Expand Down Expand Up @@ -205,6 +209,9 @@ var (
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)

var _ node.Node = &HTTPNode{}
var _ http.Handler = &HTTPNode{}

var (
forbiddenResponseHeaderRegexps []*regexp.Regexp
)
Expand Down Expand Up @@ -255,6 +262,7 @@ func init() {
}
}

// NewHTTPNode creates a new HTTPNode with the given configuration.
func NewHTTPNode(config HTTPNodeConfig) *HTTPNode {
id := config.ID
address := config.Address
Expand All @@ -278,13 +286,15 @@ func NewHTTPNode(config HTTPNodeConfig) *HTTPNode {
return n
}

// ID returns the ID of the HTTPNode.
func (n *HTTPNode) ID() ulid.ULID {
n.mu.RLock()
defer n.mu.RUnlock()

return n.id
}

// Port returns the port with the given name and a boolean indicating success.
func (n *HTTPNode) Port(name string) (*port.Port, bool) {
n.mu.RLock()
defer n.mu.RUnlock()
Expand All @@ -304,6 +314,7 @@ func (n *HTTPNode) Port(name string) (*port.Port, bool) {
return nil, false
}

// ListenerAddr returns the address of the listener associated with the HTTPNode.
func (n *HTTPNode) ListenerAddr() net.Addr {
n.mu.RLock()
defer n.mu.RUnlock()
Expand All @@ -313,6 +324,7 @@ func (n *HTTPNode) ListenerAddr() net.Addr {
return n.listener.Addr()
}

// WaitForListen waits for the HTTPNode to start listening.
func (n *HTTPNode) WaitForListen(errChan <-chan error) error {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
Expand All @@ -338,7 +350,8 @@ func (n *HTTPNode) WaitForListen(errChan <-chan error) error {
}
}

func (n *HTTPNode) Start() error {
// Serve starts serving HTTP requests.
func (n *HTTPNode) Serve() error {
n.mu.Lock()
n.server.Addr = n.address
if err := n.configureServer(); err != nil {
Expand All @@ -349,6 +362,15 @@ func (n *HTTPNode) Start() error {
return n.server.Serve(n.listener)
}

// Shutdown gracefully shuts down the HTTP server.
func (n *HTTPNode) Shutdown(ctx context.Context) error {
n.mu.Lock()
defer n.mu.Unlock()

return n.server.Shutdown(ctx)
}

// Close closes the HTTPNode.
func (n *HTTPNode) Close() error {
n.mu.RLock()
defer n.mu.RUnlock()
Expand All @@ -364,6 +386,7 @@ func (n *HTTPNode) Close() error {
return nil
}

// ServeHTTP handles HTTP requests for the HTTPNode.
func (n *HTTPNode) ServeHTTP(w http.ResponseWriter, r *http.Request) {
n.mu.RLock()
defer n.mu.RUnlock()
Expand Down
16 changes: 11 additions & 5 deletions pkg/plugin/networkx/http_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package networkx

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/phayes/freeport"
"github.com/siyul-park/uniflow/pkg/node"
Expand All @@ -25,7 +27,7 @@ func TestNewHTTPNode(t *testing.T) {
assert.NotNil(t, n)
assert.NotZero(t, n.ID())

_ = n.Close()
assert.NoError(t, n.Close())
}

func TestHTTPNode_Port(t *testing.T) {
Expand All @@ -35,7 +37,7 @@ func TestHTTPNode_Port(t *testing.T) {
n := NewHTTPNode(HTTPNodeConfig{
Address: fmt.Sprintf(":%d", port),
})
defer func() { _ = n.Close() }()
defer n.Close()

p, ok := n.Port(node.PortIO)
assert.True(t, ok)
Expand All @@ -54,26 +56,30 @@ func TestHTTPNode_Port(t *testing.T) {
assert.NotNil(t, p)
}

func TestHTTPNode_StartAndClose(t *testing.T) {
func TestHTTPNode_ServeAndShutdown(t *testing.T) {
port, err := freeport.GetFreePort()
assert.NoError(t, err)

n := NewHTTPNode(HTTPNodeConfig{
Address: fmt.Sprintf(":%d", port),
})
defer n.Close()

errChan := make(chan error)

go func() {
if err := n.Start(); err != nil {
if err := n.Serve(); err != nil {
errChan <- err
}
}()

err = n.WaitForListen(errChan)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

assert.NoError(t, err)
assert.NoError(t, n.Close())
assert.NoError(t, n.Shutdown(ctx))
}

func TestHTTPNode_ServeHTTP(t *testing.T) {
Expand Down

0 comments on commit 3747184

Please sign in to comment.