diff --git a/pkg/remote/trans/nphttp2/conn_pool.go b/pkg/remote/trans/nphttp2/conn_pool.go
index cdd1a2f785..56be72761f 100644
--- a/pkg/remote/trans/nphttp2/conn_pool.go
+++ b/pkg/remote/trans/nphttp2/conn_pool.go
@@ -121,7 +121,10 @@ func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, netwo
 		opts,
 		p.remoteService,
 		func(grpc.GoAwayReason) {
-			// do nothing
+			// remove connection from the pool.
+			// we do not need to close this grpc transport manually
+			// since grpc client is responsible for doing this.
+			p.conns.Delete(address)
 		},
 		func() {
 			// do nothing
diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf.go b/pkg/remote/trans/nphttp2/grpc/controlbuf.go
index 7dd4701124..b4edd6d7ce 100644
--- a/pkg/remote/trans/nphttp2/grpc/controlbuf.go
+++ b/pkg/remote/trans/nphttp2/grpc/controlbuf.go
@@ -22,6 +22,7 @@ package grpc
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"runtime"
 	"sync"
@@ -145,10 +146,11 @@ func (h *headerFrame) isTransportResponseFrame() bool {
 }
 
 type cleanupStream struct {
-	streamID uint32
-	rst      bool
-	rstCode  http2.ErrCode
-	onWrite  func()
+	streamID      uint32
+	rst           bool
+	rstCode       http2.ErrCode
+	onWrite       func()
+	onFinishWrite func()
 }
 
 func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
@@ -451,19 +453,20 @@ func (c *controlBuffer) get(block bool) (interface{}, error) {
 		select {
 		case <-c.ch:
 		case <-c.done:
-			c.finish()
-			return nil, ErrConnClosing
+			return nil, c.finish(ErrConnClosing)
 		}
 	}
 }
 
-func (c *controlBuffer) finish() {
+func (c *controlBuffer) finish(err error) (rErr error) {
 	c.mu.Lock()
 	if c.err != nil {
+		rErr = c.err
 		c.mu.Unlock()
 		return
 	}
-	c.err = ErrConnClosing
+	c.err = err
+	rErr = err
 	// There may be headers for streams in the control buffer.
 	// These streams need to be cleaned out since the transport
 	// is still not aware of these yet.
@@ -473,10 +476,11 @@
 			continue
 		}
 		if hdr.onOrphaned != nil { // It will be nil on the server-side.
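For illustration, a minimal sketch of the intent behind the new onGoAway hook in conn_pool.go; the pool and dial helper below are simplified stand-ins, not the real Kitex types. On GOAWAY the draining connection is only evicted from the pool so later calls dial a fresh transport, while the gRPC client transport closes the old one itself.

import "sync"

type transport struct{ address string } // stand-in for the real client transport

type pool struct{ conns sync.Map }

func (p *pool) getOrDial(address string, dial func(onGoAway func()) *transport) *transport {
	if v, ok := p.conns.Load(address); ok {
		return v.(*transport)
	}
	tr := dial(func() {
		// mirrors the patched callback: only drop the draining connection;
		// the transport itself is closed by the gRPC client when it is done.
		p.conns.Delete(address)
	})
	p.conns.Store(address, tr)
	return tr
}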
- hdr.onOrphaned(ErrConnClosing) + hdr.onOrphaned(err) } } c.mu.Unlock() + return } type side int @@ -564,6 +568,10 @@ func (l *loopyWriter) run(remoteAddr string) (err error) { klog.Debugf("KITEX: grpc transport loopyWriter.run returning, error=%v, remoteAddr=%s", err, remoteAddr) err = nil } + // make sure the Graceful Shutdown behaviour triggered + if errors.Is(err, errGracefulShutdown) { + l.framer.writer.Flush() + } }() for { it, err := l.cbuf.get(true) @@ -778,6 +786,11 @@ func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequ } func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { + if c.onFinishWrite != nil { + defer func() { + c.onFinishWrite() + }() + } c.onWrite() if str, ok := l.estdStreams[c.streamID]; ok { // On the server side it could be a trailers-only response or diff --git a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go index 643a4e3de5..3f406b21d8 100644 --- a/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go +++ b/pkg/remote/trans/nphttp2/grpc/controlbuf_test.go @@ -71,5 +71,5 @@ func TestControlBuf(t *testing.T) { cb.throttle() // test finish() - cb.finish() + cb.finish(ErrConnClosing) } diff --git a/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go b/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go new file mode 100644 index 0000000000..8052dbabf1 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/graceful_shutdown_test.go @@ -0,0 +1,60 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package grpc + +import ( + "context" + "errors" + "math" + "strings" + "testing" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" +) + +func TestGracefulShutdown(t *testing.T) { + onGoAwayCh := make(chan struct{}) + srv, cli := setUpWithOnGoAway(t, 10000, &ServerConfig{MaxStreams: math.MaxUint32}, gracefulShutdown, ConnectOptions{}, func(reason GoAwayReason) { + close(onGoAwayCh) + }) + defer cli.Close(errSelfCloseForTest) + + stream, err := cli.NewStream(context.Background(), &CallHdr{}) + test.Assert(t, err == nil, err) + <-srv.srvReady + finishCh := make(chan struct{}) + go srv.gracefulShutdown(finishCh) + err = cli.Write(stream, nil, []byte("hello"), &Options{}) + test.Assert(t, err == nil, err) + msg := make([]byte, 5) + num, err := stream.Read(msg) + test.Assert(t, err == nil, err) + test.Assert(t, num == 5, num) + // waiting onGoAway triggered + <-onGoAwayCh + // this transport could not create Stream anymore + _, err = cli.NewStream(context.Background(), &CallHdr{}) + test.Assert(t, errors.Is(err, errStreamDrain), err) + // wait for the server transport to be closed + <-finishCh + _, err = stream.Read(msg) + test.Assert(t, err != nil, err) + st := stream.Status() + test.Assert(t, strings.Contains(st.Message(), gracefulShutdownMsg), st.Message()) + test.Assert(t, st.Code() == codes.Unavailable, st.Code()) +} diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index 58b2eaba01..3766b2a50e 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -22,6 +22,8 @@ package grpc import ( "context" + "errors" + "fmt" "io" "math" "net" @@ -468,9 +470,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea s.id = h.streamID s.fc = &inFlow{limit: uint32(t.initialWindowSize)} t.mu.Lock() - if t.activeStreams == nil { // Can be niled from Close(). + // Don't create a stream if the transport is in a state of graceful shutdown or already closed + if t.state == draining || t.activeStreams == nil { // Can be niled from Close(). t.mu.Unlock() - return false // Don't create a stream if the transport is already closed. + return false } t.activeStreams[s.id] = s t.mu.Unlock() @@ -533,11 +536,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) { ) if err != nil { rst = true - rstCode = http2.ErrCodeCancel + if errors.Is(err, errGracefulShutdown) { + rstCode = gracefulShutdownCode + } else { + rstCode = http2.ErrCodeCancel + } } t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) } +// before invoking closeStream, pls do not hold the t.mu +// because accessing the controlbuf while holding t.mu will cause a deadlock. func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { // Set stream status to done. 
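To make the client-side contract concrete: once the transport enters draining, NewStream fails fast with errStreamDrain (as the test above asserts), and the caller is expected to retry on a different connection, since the pool has already evicted this one via onGoAway. A hypothetical helper inside this package might look like:

import (
	"context"
	"errors"
)

func newStreamPreferFresh(ctx context.Context, draining, fresh *http2Client, hdr *CallHdr) (*Stream, error) {
	s, err := draining.NewStream(ctx, hdr)
	if err == nil {
		return s, nil
	}
	if errors.Is(err, errStreamDrain) {
		// the draining transport refuses new streams after GOAWAY;
		// fall back to a freshly dialed transport.
		return fresh.NewStream(ctx, hdr)
	}
	return nil, err
}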
if s.swapState(streamDone) == streamDone { @@ -617,7 +626,7 @@ func (t *http2Client) Close(err error) error { t.kpDormancyCond.Signal() } t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(err) t.cancel() cErr := t.conn.Close() @@ -812,7 +821,13 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { statusCode = codes.DeadlineExceeded } } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) + var msg string + if f.ErrCode == gracefulShutdownCode { + msg = gracefulShutdownMsg + } else { + msg = fmt.Sprintf("stream terminated by RST_STREAM with error code: %v", f.ErrCode) + } + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(statusCode, msg), nil, false) } func (t *http2Client) handleSettings(f *grpcframe.SettingsFrame, isFirst bool) { @@ -917,10 +932,12 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { // Notify the clientconn about the GOAWAY before we set the state to // draining, to allow the client to stop attempting to create streams // before disallowing new streams on this connection. - if t.onGoAway != nil { - t.onGoAway(t.goAwayReason) + if t.state != draining { + if t.onGoAway != nil { + t.onGoAway(t.goAwayReason) + } + t.state = draining } - t.state = draining } // All streams with IDs greater than the GoAwayId // and smaller than the previous GoAway ID should be killed. @@ -928,18 +945,29 @@ func (t *http2Client) handleGoAway(f *grpcframe.GoAwayFrame) { if upperLimit == 0 { // This is the first GoAway Frame. upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. } + t.prevGoAwayID = id + active := len(t.activeStreams) + if active <= 0 { + t.mu.Unlock() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + return + } + + var unprocessedStream []*Stream for streamID, stream := range t.activeStreams { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. atomic.StoreUint32(&stream.unprocessed, 1) - t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) + unprocessedStream = append(unprocessedStream, stream) } } - t.prevGoAwayID = id - active := len(t.activeStreams) t.mu.Unlock() - if active == 0 { - t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) + + // we should not access controlBuf with t.mu held since it will cause deadlock. + // Pls refer to checkForStreamQuota in NewStream, it gets the controlbuf.mu and + // wants to get the t.mu. + for _, stream := range unprocessedStream { + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index c2b84efe13..03320a1e5e 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -50,6 +50,11 @@ import ( "github.com/cloudwego/kitex/pkg/utils" ) +const ( + gracefulShutdownCode = http2.ErrCode(1000) + gracefulShutdownMsg = "graceful shutdown" +) + var ( // ErrIllegalHeaderWrite indicates that setting header is illegal because of // the stream's state. @@ -57,7 +62,8 @@ var ( // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. 
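The handleGoAway rework above follows a collect-under-lock, act-after-unlock pattern to avoid the t.mu / controlbuf.mu deadlock it describes. Reduced to a self-contained sketch (illustrative types only):

import "sync"

type strm struct{ id uint32 }

type trans struct {
	mu      sync.Mutex
	streams map[uint32]*strm
}

// Snapshot the affected streams while holding mu, then release mu before
// running the close callback, which may take other locks (the control
// buffer in the real transport).
func (t *trans) closeUnprocessed(lastAccepted, upper uint32, closeStream func(*strm)) {
	t.mu.Lock()
	var victims []*strm
	for id, s := range t.streams {
		if id > lastAccepted && id <= upper {
			victims = append(victims, s)
		}
	}
	t.mu.Unlock()
	for _, s := range victims {
		closeStream(s)
	}
}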
- ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + errStatusHeaderListSizeLimitViolation = status.Err(codes.Internal, ErrHeaderListSizeLimitViolation.Error()) // errors used for cancelling stream. // the code should be codes.Canceled coz it's NOT returned from remote @@ -67,6 +73,8 @@ var ( errNotReachable = status.New(codes.Canceled, "transport: server not reachable").Err() errMaxAgeClosing = status.New(codes.Canceled, "transport: closing server transport due to maximum connection age").Err() errIdleClosing = status.New(codes.Canceled, "transport: closing server transport due to idleness").Err() + + errGracefulShutdown = status.Err(codes.Unavailable, gracefulShutdownMsg) ) func init() { @@ -81,7 +89,7 @@ type http2Server struct { conn net.Conn loopy *loopyWriter readerDone chan struct{} // sync point to enable testing. - writerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // denote that the loopyWriter has stopped. remoteAddr net.Addr localAddr net.Addr maxStreamID uint32 // max stream ID ever seen @@ -115,9 +123,11 @@ type http2Server struct { // During this time we don't want to write another first GoAway(with ID 2^31 -1) frame. // Thus call to drain(...) will be a no-op if drainChan is already initialized since draining is // already underway. - drainChan chan struct{} - state transportState - activeStreams map[uint32]*Stream + drainChan chan struct{} + // denote that drainChan has been closed to prevent remote side sending more ping Ack with goAway + drainChanClosed bool + state transportState + activeStreams map[uint32]*Stream // idle is the time instant when the connection went idle. // This is either the beginning of the connection or when the number of // RPCs go down to 0. @@ -279,8 +289,9 @@ func newHTTP2Server(ctx context.Context, conn net.Conn, config *ServerConfig) (_ gofunc.RecoverGoFuncWithInfo(ctx, func() { t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst) t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler - if err := t.loopy.run(conn.RemoteAddr().String()); err != nil { - klog.CtxErrorf(ctx, "KITEX: grpc server loopyWriter.run returning, error=%v", err) + runErr := t.loopy.run(conn.RemoteAddr().String()) + if runErr != nil { + klog.CtxErrorf(ctx, "KITEX: grpc server loopyWriter.run returning, error=%v", runErr) } t.conn.Close() close(t.writerDone) @@ -408,7 +419,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. // it will be codes.Internal error for GRPC // TODO: map http2.StreamError to status.Error? 
s.cancel(err) - t.closeStream(s, true, se.Code, false) + t.closeStream(s, status.Errorf(codes.Canceled, "transport: ReadFrame encountered http2.StreamError: %v", err), true, se.Code, false) } else { t.controlBuf.put(&cleanupStream{ streamID: se.StreamID, @@ -551,7 +562,7 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, true, http2.ErrCodeFlowControl, false) + t.closeStream(s, status.Errorf(codes.Canceled, "transport: inflow control err: %v", err), true, http2.ErrCodeFlowControl, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -579,7 +590,11 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) { func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { // If the stream is not deleted from the transport's active streams map, then do a regular close stream. if s, ok := t.getStream(f); ok { - t.closeStream(s, false, 0, false) + if f.ErrCode == gracefulShutdownCode { + t.closeStream(s, errGracefulShutdown, false, 0, false) + } else { + t.closeStream(s, status.Errorf(codes.Canceled, "transport: RSTStream Frame received with error code: %v", f.ErrCode), false, 0, false) + } return } // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. @@ -626,9 +641,13 @@ const ( func (t *http2Server) handlePing(f *http2.PingFrame) { if f.IsAck() { - if f.Data == goAwayPing.data && t.drainChan != nil { - close(t.drainChan) - return + if f.Data == goAwayPing.data { + t.mu.Lock() + if !t.drainChanClosed { + close(t.drainChan) + t.drainChanClosed = true + } + t.mu.Unlock() } // Maybe it's a BDP ping. if t.bdpEst != nil { @@ -756,7 +775,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) + t.closeStream(s, errStatusHeaderListSizeLimitViolation, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } return nil @@ -819,7 +838,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { if err != nil { return err } - t.closeStream(s, true, http2.ErrCodeInternal, false) + t.closeStream(s, errStatusHeaderListSizeLimitViolation, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } // Send a RST_STREAM after the trailers if the client has not already half-closed. @@ -910,12 +929,12 @@ func (t *http2Server) keepalive() { if val <= 0 { // The connection has been idle for a duration of keepalive.MaxConnectionIdle or more. // Gracefully close the connection. - t.drain(http2.ErrCodeNo, []byte{}) + t.drain(http2.ErrCodeNo, []byte("idleTimeout")) return } idleTimer.Reset(val) case <-ageTimer.C: - t.drain(http2.ErrCodeNo, []byte{}) + t.drain(http2.ErrCodeNo, []byte("ageTimeout")) ageTimer.Reset(t.kp.MaxConnectionAgeGrace) select { case <-ageTimer.C: @@ -963,7 +982,71 @@ func (t *http2Server) keepalive() { // TODO(zhaoq): Now the destruction is not blocked on any pending streams. This // could cause some resource issue. Revisit this later. 
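The drainChanClosed flag introduced above guards handlePing against the remote sending more than one ping ack carrying the GoAway payload, since closing drainChan twice would panic. As a standalone sketch of that guard (sync.Once would also work, but the transport already serializes on its own mutex):

import "sync"

type drainSignal struct {
	mu     sync.Mutex
	ch     chan struct{} // created when draining starts, like drainChan
	closed bool
}

// signal closes ch exactly once, no matter how many acks arrive.
func (d *drainSignal) signal() {
	d.mu.Lock()
	if !d.closed {
		close(d.ch)
		d.closed = true
	}
	d.mu.Unlock()
}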
func (t *http2Server) Close() error { - return t.closeWithErr(nil) + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return nil + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + + finishErr := errGracefulShutdown + finishCh := make(chan struct{}, 1) + activeNums := t.rstActiveStreams(streams, finishErr, gracefulShutdownCode, finishCh) + if activeNums == 0 { + t.closeLoopyWriter(finishErr) + return nil + } + for { + select { + // wait for all the RstStream Frames to be written + case <-finishCh: + activeNums-- + if activeNums == 0 { + t.closeLoopyWriter(finishErr) + return nil + } + // loopyWriter has quited, there is no chance to write the RstStream Frame + case <-t.writerDone: + t.closeLoopyWriter(finishErr) + return nil + } + } +} + +func (t *http2Server) closeLoopyWriter(err error) { + t.controlBuf.finish(err) + close(t.done) + // make use of loopyWriter.run returning to close the connection + // there is no need to close the connection manually + <-t.writerDone +} + +// rstActiveStreams sends RSTStream frames to all active streams. +func (t *http2Server) rstActiveStreams(streams map[uint32]*Stream, cancelErr error, rstCode http2.ErrCode, finishCh chan struct{}) (activeStreams int) { + for _, s := range streams { + oldState := s.swapState(streamDone) + if oldState == streamDone { + // If the stream was already done, continue + continue + } + activeStreams++ + // cancel the downstream + s.cancel(cancelErr) + // send RSTStream Frame to the upstream + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: true, + rstCode: rstCode, + onWrite: func() {}, + onFinishWrite: func() { + finishCh <- struct{}{} + }, + }) + } + return activeStreams } func (t *http2Server) closeWithErr(reason error) error { @@ -976,7 +1059,7 @@ func (t *http2Server) closeWithErr(reason error) error { streams := t.activeStreams t.activeStreams = nil t.mu.Unlock() - t.controlBuf.finish() + t.controlBuf.finish(reason) close(t.done) err := t.conn.Close() @@ -1025,7 +1108,11 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h } // closeStream clears the footprint of a stream when the stream is not needed any more. -func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { +func (t *http2Server) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel(err) s.swapState(streamDone) t.deleteStream(s, eosReceived) @@ -1046,16 +1133,19 @@ func (t *http2Server) LocalAddr() net.Addr { } func (t *http2Server) Drain() { - t.drain(http2.ErrCodeNo, []byte{}) + t.drain(http2.ErrCodeNo, []byte(gracefulShutdownMsg)) } func (t *http2Server) drain(code http2.ErrCode, debugData []byte) { t.mu.Lock() - defer t.mu.Unlock() if t.drainChan != nil { + t.mu.Unlock() return } t.drainChan = make(chan struct{}) + t.mu.Unlock() + // drain successfully + // should release lock before access controlBuf t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true}) } @@ -1096,7 +1186,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { // originated before the GoAway reaches the client. // After getting the ack or timer expiration send out another GoAway this // time with an ID of the max stream server intends to process. 
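The new Close path above queues one RST_STREAM per active stream and then waits for either every onFinishWrite callback or the writer goroutine to exit. A condensed, self-contained version of that wait loop (names are illustrative):

type frame struct {
	onFinishWrite func() // runs after the frame has been written
}

func rstAndWait(active int, enqueue func(frame), writerDone <-chan struct{}, shutdown func()) {
	finishCh := make(chan struct{}, active)
	for i := 0; i < active; i++ {
		enqueue(frame{onFinishWrite: func() { finishCh <- struct{}{} }})
	}
	for active > 0 {
		select {
		case <-finishCh:
			active-- // one more RST_STREAM made it onto the wire
		case <-writerDone:
			active = 0 // loopy writer is gone; nothing more can be written
		}
	}
	shutdown() // finish the control buffer and let the loopy writer close the connection
}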
- if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + if err := t.framer.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil { return false, err } if err := t.framer.WritePing(false, goAwayPing.data); err != nil { @@ -1113,6 +1203,6 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { return } t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData}) - }, gofunc.EmptyInfo) + }, gofunc.NewBasicInfo("", t.conn.RemoteAddr().String())) // we should create a new BasicInfo here return false, nil } diff --git a/pkg/remote/trans/nphttp2/grpc/http_util.go b/pkg/remote/trans/nphttp2/grpc/http_util.go index ecb6bb7ec1..49af9275cf 100644 --- a/pkg/remote/trans/nphttp2/grpc/http_util.go +++ b/pkg/remote/trans/nphttp2/grpc/http_util.go @@ -77,6 +77,7 @@ var ( http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, http2.ErrCodeInadequateSecurity: codes.PermissionDenied, http2.ErrCodeHTTP11Required: codes.Internal, + gracefulShutdownCode: codes.Unavailable, } statusCodeConvTab = map[codes.Code]http2.ErrCode{ codes.Internal: http2.ErrCodeInternal, diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 5e0f82f482..cc447bc9d9 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -57,6 +57,10 @@ type server struct { conns map[ServerTransport]bool h *testStreamHandler ready chan struct{} + hdlWG sync.WaitGroup + transWG sync.WaitGroup + + srvReady chan struct{} } var ( @@ -77,6 +81,7 @@ func init() { type testStreamHandler struct { t *http2Server + srv *server notify chan struct{} getNotified chan struct{} } @@ -92,6 +97,8 @@ const ( invalidHeaderField delayRead pingpong + + gracefulShutdown ) func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { @@ -292,6 +299,24 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { } } +func (h *testStreamHandler) gracefulShutdown(t *testing.T, s *Stream) { + t.Log("run graceful shutdown") + close(h.srv.srvReady) + msg := make([]byte, 5) + num, err := s.Read(msg) + test.Assert(t, err == nil, err) + test.Assert(t, num == 5, num) + test.Assert(t, string(msg) == "hello", string(msg)) + err = h.t.Write(s, nil, msg, &Options{}) + test.Assert(t, err == nil, err) + _, err = s.Read(msg) + test.Assert(t, err != nil, err) + test.Assert(t, strings.Contains(err.Error(), gracefulShutdownMsg), err) + st, ok := status.FromError(err) + test.Assert(t, ok, err) + test.Assert(t, st.Code() == codes.Unavailable, st) +} + // start starts server. Other goroutines should block on s.readyChan for further operations. 
func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { // 创建 listener @@ -328,6 +353,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT s.conns[transport] = true h := &testStreamHandler{t: transport.(*http2Server)} s.h = h + h.srv = s s.mu.Unlock() switch ht { case notifyCall: @@ -378,12 +404,26 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT }, func(ctx context.Context, method string) context.Context { return ctx }) + case gracefulShutdown: + s.transWG.Add(1) + go func() { + defer s.transWG.Done() + transport.HandleStreams(func(stream *Stream) { + s.hdlWG.Add(1) + go func() { + defer s.hdlWG.Done() + h.gracefulShutdown(t, stream) + }() + }, func(ctx context.Context, method string) context.Context { return ctx }) + }() default: - go transport.HandleStreams(func(s *Stream) { - go h.handleStream(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx - }) + go func() { + transport.HandleStreams(func(s *Stream) { + go h.handleStream(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx + }) + }() } return ctx } @@ -422,9 +462,6 @@ func (s *server) wait(t *testing.T, timeout time.Duration) { func (s *server) stop() { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if err := s.eventLoop.Shutdown(ctx); err != nil { - fmt.Printf("netpoll server exit failed, err=%v", err) - } s.lis.Close() s.mu.Lock() for c := range s.conns { @@ -432,6 +469,40 @@ func (s *server) stop() { } s.conns = nil s.mu.Unlock() + if err := s.eventLoop.Shutdown(ctx); err != nil { + fmt.Printf("netpoll server exit failed, err=%v", err) + } +} + +func (s *server) gracefulShutdown(finishCh chan struct{}) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + s.lis.Close() + s.mu.Lock() + for trans := range s.conns { + trans.Drain() + } + s.mu.Unlock() + timeout, _ := ctx.Deadline() + graceTimer := time.NewTimer(time.Until(timeout)) + exitCh := make(chan struct{}) + go func() { + select { + case <-graceTimer.C: + s.mu.Lock() + for trans := range s.conns { + trans.Close() + } + s.mu.Unlock() + return + case <-exitCh: + return + } + }() + s.hdlWG.Wait() + s.transWG.Wait() + close(exitCh) + close(finishCh) } func (s *server) addr() string { @@ -442,7 +513,7 @@ func (s *server) addr() string { } func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { - server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})} + server := &server{startedErr: make(chan error, 1), ready: make(chan struct{}), srvReady: make(chan struct{})} go server.start(t, port, serverConfig, ht) server.wait(t, time.Second) return server @@ -510,6 +581,19 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C return tr.(*http2Client) } +func setUpWithOnGoAway(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions, onGoAway func(reason GoAwayReason)) (*server, *http2Client) { + server := setUpServerOnly(t, port, serverConfig, ht) + conn, err := netpoll.NewDialer().DialTimeout("tcp", "localhost:"+server.port, time.Second) + if err != nil { + t.Fatalf("failed to dial connection: %v", err) + } + ct, connErr := NewClientTransport(context.Background(), conn.(netpoll.Connection), copts, "", onGoAway, func() {}) + if connErr != nil { + t.Fatalf("failed to create transport: %v", connErr) + } + return server, 
ct.(*http2Client) +} + // TestInflightStreamClosing ensures that closing in-flight stream // sends status error to concurrent stream reader. func TestInflightStreamClosing(t *testing.T) { diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index 2de1006bfe..601689a63a 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -18,6 +18,7 @@ package nphttp2 import ( "bytes" + "container/list" "context" "errors" "fmt" @@ -25,6 +26,7 @@ import ( "runtime/debug" "strings" "sync" + "sync/atomic" "time" "github.com/cloudwego/netpoll" @@ -62,6 +64,7 @@ func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { opt: opt, svcSearcher: opt.SvcSearcher, codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), + li: list.New(), }, nil } @@ -72,6 +75,10 @@ type svrTransHandler struct { svcSearcher remote.ServiceSearcher inkHdlFunc endpoint.Endpoint codec remote.Codec + + mu sync.Mutex + // maintain all active server transports + li *list.List } var prefaceReadAtMost = func() int { @@ -123,10 +130,11 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Me func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans) tr := svrTrans.tr - tr.HandleStreams(func(s *grpcTransport.Stream) { + atomic.AddInt32(&svrTrans.handlerNum, 1) gofunc.GoFunc(ctx, func() { t.handleFunc(s, svrTrans, conn) + atomic.AddInt32(&svrTrans.handlerNum, -1) }) }, func(ctx context.Context, method string) context.Context { return ctx @@ -298,11 +306,21 @@ func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Mes type svrTransKey int -const ctxKeySvrTransport svrTransKey = 1 +const ( + ctxKeySvrTransport svrTransKey = 1 + // align with default exitWaitTime + defaultGraceTime time.Duration = 5 * time.Second + // max poll time to check whether all the transports have finished + // A smaller poll time will not affect performance and will speed up our unit tests + defaultMaxPollTime = 50 * time.Millisecond +) type SvrTrans struct { tr grpcTransport.ServerTransport pool *sync.Pool // value is rpcInfo + elem *list.Element + // num of active handlers + handlerNum int32 } // 新连接建立时触发,主要用于服务端,对应 netpoll onPrepare @@ -326,13 +344,22 @@ func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context. 
return ri }, } - ctx = context.WithValue(ctx, ctxKeySvrTransport, &SvrTrans{tr: tr, pool: pool}) + svrTrans := &SvrTrans{tr: tr, pool: pool} + t.mu.Lock() + elem := t.li.PushBack(svrTrans) + t.mu.Unlock() + svrTrans.elem = elem + ctx = context.WithValue(ctx, ctxKeySvrTransport, svrTrans) return ctx, nil } // 连接关闭时回调 func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { - tr := ctx.Value(ctxKeySvrTransport).(*SvrTrans).tr + svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans) + tr := svrTrans.tr + t.mu.Lock() + t.li.Remove(svrTrans.elem) + t.mu.Unlock() tr.Close() } @@ -353,6 +380,64 @@ func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { } +func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { + klog.Info("KITEX: gRPC GracefulShutdown starts") + defer func() { + klog.Info("KITEX: gRPC GracefulShutdown ends") + }() + t.mu.Lock() + for elem := t.li.Front(); elem != nil; elem = elem.Next() { + svrTrans := elem.Value.(*SvrTrans) + svrTrans.tr.Drain() + } + t.mu.Unlock() + graceTime, pollTime := parseGraceAndPollTime(ctx) + graceTimer := time.NewTimer(graceTime) + defer graceTimer.Stop() + pollTick := time.NewTicker(pollTime) + defer pollTick.Stop() + for { + select { + case <-pollTick.C: + var activeNums int32 + t.mu.Lock() + for elem := t.li.Front(); elem != nil; elem = elem.Next() { + svrTrans := elem.Value.(*SvrTrans) + activeNums += atomic.LoadInt32(&svrTrans.handlerNum) + } + t.mu.Unlock() + if activeNums == 0 { + return nil + } + case <-graceTimer.C: + klog.Info("KITEX: gRPC triggers Close") + t.mu.Lock() + for elem := t.li.Front(); elem != nil; elem = elem.Next() { + svrTrans := elem.Value.(*SvrTrans) + svrTrans.tr.Close() + } + t.mu.Unlock() + return nil + } + } +} + +func parseGraceAndPollTime(ctx context.Context) (graceTime, pollTime time.Duration) { + graceTime = defaultGraceTime + deadline, ok := ctx.Deadline() + if ok { + if customTime := time.Until(deadline); customTime > 0 { + graceTime = customTime + } + } + + pollTime = graceTime / 10 + if pollTime > defaultMaxPollTime { + pollTime = defaultMaxPollTime + } + return +} + func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context { c := t.opt.TracerCtl.DoStart(ctx, ri) return c diff --git a/pkg/remote/trans/nphttp2/server_handler_test.go b/pkg/remote/trans/nphttp2/server_handler_test.go index 3ae92b50c6..4d38c58447 100644 --- a/pkg/remote/trans/nphttp2/server_handler_test.go +++ b/pkg/remote/trans/nphttp2/server_handler_test.go @@ -341,3 +341,26 @@ func TestSvrTransHandlerProtocolMatch(t *testing.T) { err = th.ProtocolMatch(context.Background(), rawConn) test.Assert(t, err != nil, err) } + +func Test_parseGraceAndPollTime(t *testing.T) { + // without timeout in ctx + ctx := context.Background() + graceTime, pollTime := parseGraceAndPollTime(ctx) + test.Assert(t, graceTime == defaultGraceTime, graceTime) + test.Assert(t, pollTime == defaultMaxPollTime, pollTime) + + // with timeout longer than default timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + graceTime, pollTime = parseGraceAndPollTime(ctx) + test.Assert(t, graceTime > defaultGraceTime, graceTime) + // defaultMaxPollTime is the max poll time + test.Assert(t, pollTime == defaultMaxPollTime, pollTime) + + // with timeout shorter than default timeout + ctx, cancel = context.WithTimeout(context.Background(), 400*time.Millisecond) + defer cancel() + graceTime, 
pollTime = parseGraceAndPollTime(ctx)
+	test.Assert(t, graceTime < defaultGraceTime, graceTime)
+	test.Assert(t, pollTime < defaultMaxPollTime, pollTime)
+}
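Finally, an illustrative view of how the new GracefulShutdown hook is meant to be driven; the handler value and context come from the caller (for example Kitex's exit-wait logic), so this snippet is an assumption about usage rather than part of the patch. The context deadline becomes the grace period, and transports are polled every graceTime/10, capped at defaultMaxPollTime (50ms).

import (
	"context"
	"time"
)

func shutdownExample(handler interface {
	GracefulShutdown(ctx context.Context) error
}) {
	// A 3s deadline means graceTime=3s and pollTime=50ms (3s/10, capped at 50ms);
	// without a deadline the defaults are 5s and 50ms.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	_ = handler.GracefulShutdown(ctx) // drains all transports, force-closes leftovers at the deadline
}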