Skip to content

Commit

Permalink
feat: support gRPC graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Sep 19, 2024
1 parent 4e1dbe9 commit 67b4624
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pkg/remote/trans/nphttp2/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo
opts,
p.remoteService,
func(grpc.GoAwayReason) {
// do nothing
p.Clean(network, address)
},
func() {
// do nothing
Expand Down
52 changes: 52 additions & 0 deletions pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package grpc

import (
"context"
"math"
"testing"
"time"

"github.com/cloudwego/kitex/internal/test"
)

// TestGracefulShutdown walks a single client stream through a server-side graceful shutdown.
// While the server is draining, the stream that was already open must still be able to
// exchange one message, whereas opening a new stream must fail. Once the server's grace
// period has elapsed, both writing to and reading from the old stream must fail.
func TestGracefulShutdown(t *testing.T) {
	svr, client := setUp(t, 0, math.MaxUint32, gracefulShutdown)
	defer client.Close(errSelfCloseForTest)

	ctx := context.Background()

	// Open a stream and wait until the server-side handler signals that it is running.
	active, err := client.NewStream(ctx, &CallHdr{})
	test.Assert(t, err == nil, err)
	<-svr.srvReady

	// Kick off the shutdown in the background; the in-flight stream should keep working.
	go svr.gracefulShutdown()
	err = client.Write(active, nil, []byte("hello"), &Options{})
	test.Assert(t, err == nil, err)
	reply := make([]byte, 5)
	n, err := active.Read(reply)
	test.Assert(t, err == nil, err)
	test.Assert(t, n == 5, n)

	// New streams are rejected while the transport is shutting down.
	_, err = client.NewStream(ctx, &CallHdr{})
	test.Assert(t, err != nil, err)
	t.Logf("NewStream err: %v", err)

	// Sleep past the server's grace period; afterwards the old stream is unusable.
	time.Sleep(1 * time.Second)
	err = client.Write(active, nil, []byte("hello"), &Options{})
	test.Assert(t, err != nil, err)
	t.Logf("After timeout, Write err: %v", err)
	_, err = active.Read(reply)
	test.Assert(t, err != nil, err)
	t.Logf("After timeout, Read err: %v", err)
}
30 changes: 21 additions & 9 deletions pkg/remote/trans/nphttp2/grpc/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock()
if t.activeStreams == nil { // Can be niled from Close().
// Don't create a stream if the transport is in a state of graceful shutdown or already closed
if t.state == draining || t.activeStreams == nil { // Can be niled from Close().
t.mu.Unlock()
return false // Don't create a stream if the transport is already closed.
return false
}
t.activeStreams[s.id] = s
t.mu.Unlock()
Expand Down Expand Up @@ -917,29 +918,40 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) {
// Notify the clientconn about the GOAWAY before we set the state to
// draining, to allow the client to stop attempting to create streams
// before disallowing new streams on this connection.
if t.onGoAway != nil {
t.onGoAway(t.goAwayReason)
if t.state != draining {
if t.onGoAway != nil {
t.onGoAway(t.goAwayReason)
}
t.state = draining
}
t.state = draining
}
// All streams with IDs greater than the GoAwayId
// and smaller than the previous GoAway ID should be killed.
upperLimit := t.prevGoAwayID
if upperLimit == 0 { // This is the first GoAway Frame.
upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID.
}
t.prevGoAwayID = id
active := len(t.activeStreams)
if active <= 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))
return
}

var unprocessedStream []*Stream
for streamID, stream := range t.activeStreams {
if streamID > id && streamID <= upperLimit {
// The stream was unprocessed by the server.
atomic.StoreUint32(&stream.unprocessed, 1)
unprocessedStream = append(unprocessedStream, stream)
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
}
t.prevGoAwayID = id
active := len(t.activeStreams)
t.mu.Unlock()
if active == 0 {
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))

for _, stream := range unprocessedStream {
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
}

Expand Down
8 changes: 3 additions & 5 deletions pkg/remote/trans/nphttp2/grpc/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1081,10 +1081,8 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
if err := t.framer.WriteGoAway(sid, g.code, g.debugData); err != nil {
return false, err
}
t.framer.writer.Flush()
if g.closeConn {
// Abruptly close the connection following the GoAway (via
// loopywriter). But flush out what's inside the buffer first.
t.framer.writer.Flush()
return false, fmt.Errorf("transport: Connection closing")
}
return true, nil
Expand All @@ -1096,15 +1094,15 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
// originated before the GoAway reaches the client.
// After getting the ack or timer expiration send out another GoAway this
// time with an ID of the max stream server intends to process.
if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil {
if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
if err := t.framer.WritePing(false, goAwayPing.data); err != nil {
return false, err
}

gofunc.RecoverGoFuncWithInfo(context.Background(), func() {
timer := time.NewTimer(time.Minute)
timer := time.NewTimer(10 * time.Second)
defer timer.Stop()
select {
case <-t.drainChan:
Expand Down
82 changes: 76 additions & 6 deletions pkg/remote/trans/nphttp2/grpc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ type server struct {
conns map[ServerTransport]bool
h *testStreamHandler
ready chan struct{}
hdlWG sync.WaitGroup
transWG sync.WaitGroup

srvReady chan struct{}
}

var (
Expand All @@ -77,6 +81,7 @@ func init() {

// testStreamHandler holds the per-connection state shared by the test stream handlers.
type testStreamHandler struct {
	// t is the server transport that handlers use to write responses.
	t *http2Server
	// srv is the owning test server; handlers use it to reach server-level
	// signals such as srvReady.
	srv *server
	// NOTE(review): usage of notify is not visible in this view — presumably used by
	// handleStreamAndNotify; confirm.
	notify chan struct{}
	// NOTE(review): usage of getNotified is not visible in this view; confirm.
	getNotified chan struct{}
}
Expand All @@ -92,6 +97,8 @@ const (
invalidHeaderField
delayRead
pingpong

gracefulShutdown
)

func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
Expand Down Expand Up @@ -292,6 +299,20 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
}
}

// gracefulShutdown is the server-side stream handler for the graceful shutdown test.
// It signals readiness by closing srvReady, echoes one 5-byte "hello" message back to
// the client, and then expects its next read on the stream to fail.
func (h *testStreamHandler) gracefulShutdown(t *testing.T, s *Stream) {
	// Tell the test that the handler is running and the stream is in flight.
	close(h.srv.srvReady)

	buf := make([]byte, 5)
	n, err := s.Read(buf)
	test.Assert(t, err == nil, err)
	test.Assert(t, n == 5, n)
	got := string(buf)
	test.Assert(t, got == "hello", got)

	// Echo the payload back on the same stream.
	err = h.t.Write(s, nil, buf, &Options{})
	test.Assert(t, err == nil, err)

	// The second read is expected to return an error rather than data.
	_, err = s.Read(buf)
	test.Assert(t, err != nil, err)
	t.Logf("Server-side after timeout err: %v", err)
}

// start starts the server. Other goroutines should block on s.ready before performing further operations.
func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
// 创建 listener
Expand Down Expand Up @@ -329,6 +350,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.conns[transport] = true
h := &testStreamHandler{t: transport.(*http2Server)}
s.h = h
h.srv = s
s.mu.Unlock()
switch ht {
case notifyCall:
Expand Down Expand Up @@ -379,12 +401,26 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case gracefulShutdown:
s.transWG.Add(1)
go func() {
defer s.transWG.Done()
transport.HandleStreams(func(stream *Stream) {
s.hdlWG.Add(1)
go func() {
defer s.hdlWG.Done()
h.gracefulShutdown(t, stream)
}()
}, func(ctx context.Context, method string) context.Context { return ctx })
}()
default:
go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
go func() {
transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
}()
}
return ctx
}
Expand Down Expand Up @@ -434,6 +470,40 @@ func (s *server) stop() {
s.mu.Unlock()
}

// gracefulShutdown closes the listener, drains every tracked transport, and then waits
// for all stream handlers (hdlWG) and transport loops (transWG) to finish. If they have
// not finished within the 500ms grace period, a watchdog goroutine closes the remaining
// transports. Finally it clears s.conns and shuts the event loop down.
func (s *server) gracefulShutdown() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()
	s.lis.Close()

	// Ask every live transport to drain.
	s.mu.Lock()
	for trans := range s.conns {
		trans.Drain()
	}
	s.mu.Unlock()

	timeout, _ := ctx.Deadline()
	graceTimer := time.NewTimer(time.Until(timeout))
	// Release the timer if everything finishes before the grace period expires.
	defer graceTimer.Stop()

	exitCh := make(chan struct{})
	go func() {
		select {
		case <-graceTimer.C:
			// Grace period is over: close whatever is still open.
			s.mu.Lock()
			for trans := range s.conns {
				trans.Close()
			}
			s.mu.Unlock()
			return
		case <-exitCh:
			return
		}
	}()

	s.hdlWG.Wait()
	s.transWG.Wait()
	close(exitCh)

	// s.conns is guarded by s.mu everywhere else in this function (including the
	// watchdog goroutine above), so clear it under the same lock.
	s.mu.Lock()
	s.conns = nil
	s.mu.Unlock()

	if err := s.eventLoop.Shutdown(ctx); err != nil {
		fmt.Printf("netpoll server exit failed, err=%v", err)
	}
}

func (s *server) addr() string {
if s.lis == nil {
return ""
Expand All @@ -442,7 +512,7 @@ func (s *server) addr() string {
}

func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server {
server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})}
server := &server{startedErr: make(chan error, 1), ready: make(chan struct{}), srvReady: make(chan struct{})}
go server.start(t, port, serverConfig, ht)
server.wait(t, time.Second)
return server
Expand Down
54 changes: 53 additions & 1 deletion pkg/remote/trans/nphttp2/server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) {
opt: opt,
svcSearcher: opt.SvcSearcher,
codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)),
transports: make(map[grpcTransport.ServerTransport]struct{}),
}, nil
}

Expand All @@ -72,6 +73,11 @@ type svrTransHandler struct {
svcSearcher remote.ServiceSearcher
inkHdlFunc endpoint.Endpoint
codec remote.Codec
mu sync.Mutex
transports map[grpcTransport.ServerTransport]struct{}

hdlWG sync.WaitGroup
transWG sync.WaitGroup
}

var prefaceReadAtMost = func() int {
Expand Down Expand Up @@ -119,9 +125,11 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Me
func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error {
svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans)
tr := svrTrans.tr

defer t.transWG.Done()
tr.HandleStreams(func(s *grpcTransport.Stream) {
t.hdlWG.Add(1)
gofunc.GoFunc(ctx, func() {
defer t.hdlWG.Done()
t.handleFunc(s, svrTrans, conn)
})
}, func(ctx context.Context, method string) context.Context {
Expand Down Expand Up @@ -315,6 +323,10 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.
if err != nil {
return nil, err
}
t.transWG.Add(1)
t.mu.Lock()
t.transports[tr] = struct{}{}
t.mu.Unlock()
pool := &sync.Pool{
New: func() interface{} {
// init rpcinfo
Expand All @@ -329,6 +341,9 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.
// OnInactive is the callback invoked when the connection is closed. It removes the
// connection's server transport from the tracked set and then closes the transport.
func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
	svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans)
	transport := svrTrans.tr

	// Stop tracking the transport before closing it.
	t.mu.Lock()
	delete(t.transports, transport)
	t.mu.Unlock()

	transport.Close()
}

Expand All @@ -349,6 +364,43 @@ func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
// SetPipeline is a no-op: this handler does not store or use the given pipeline.
func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) {
}

// GracefulShutdown drains every tracked server transport and blocks until all stream
// handlers (hdlWG) and transport read loops (transWG) have finished.
//
// The grace period is the time remaining until ctx's deadline, or three minutes when
// ctx carries no deadline. If the grace period expires first, a watchdog goroutine
// closes all transports that are still tracked. The method always returns nil.
func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
	// Ask every live transport to drain.
	t.mu.Lock()
	for trans := range t.transports {
		trans.Drain()
	}
	t.mu.Unlock()

	exitCh := make(chan struct{})
	// todo: think about a better grace time duration
	graceTime := time.Minute * 3
	if exitTimeout, ok := ctx.Deadline(); ok {
		graceTime = time.Until(exitTimeout)
	}
	graceTimer := time.NewTimer(graceTime)
	// Release the timer as soon as shutdown completes instead of leaving it pending
	// for the rest of the grace period.
	defer graceTimer.Stop()

	gofunc.GoFunc(ctx, func() {
		select {
		case <-graceTimer.C:
			// Grace period is over: close whatever is still open.
			t.mu.Lock()
			for trans := range t.transports {
				// todo: use CloseWithErr
				trans.Close()
			}
			t.mu.Unlock()
			return
		case <-exitCh:
			return
		}
	})

	t.hdlWG.Wait()
	t.transWG.Wait()
	// Everything finished; let the watchdog goroutine exit.
	close(exitCh)

	return nil
}

func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context {
c := t.opt.TracerCtl.DoStart(ctx, ri)
return c
Expand Down

0 comments on commit 67b4624

Please sign in to comment.