Skip to content

Commit

Permalink
增加读取超时方法
Browse files Browse the repository at this point in the history
  • Loading branch information
吴迎松 committed May 9, 2017
1 parent 7023afa commit 2174fad
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 34 deletions.
26 changes: 19 additions & 7 deletions tcp_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
)

var (
Expand All @@ -21,10 +22,11 @@ type TCPConn struct {
readChan chan Packet
writeChan chan Packet

exitChan chan struct{}
closeOnce sync.Once
exitFlag int32
err error
readDeadline time.Duration
exitChan chan struct{}
closeOnce sync.Once
exitFlag int32
err error
}

func NewTCPConn(conn *net.TCPConn, callback CallBack, protocol Protocol) *TCPConn {
Expand All @@ -35,9 +37,8 @@ func NewTCPConn(conn *net.TCPConn, callback CallBack, protocol Protocol) *TCPCon
readChan: make(chan Packet, readChanSize),
writeChan: make(chan Packet, writeChanSize),
exitChan: make(chan struct{}),
exitFlag: 1,
exitFlag: 0,
}
c.Serve()
return c
}

Expand All @@ -47,6 +48,7 @@ func (c *TCPConn) Serve() {
logger.Println("tcp conn(%v) Serve error, %v ", c.RemoteIP(), r)
}
}()
atomic.StoreInt32(&c.exitFlag, 1)
c.callback.OnConnected(c)
go c.readLoop()
go c.writeLoop()
Expand All @@ -64,9 +66,11 @@ func (c *TCPConn) readLoop() {
case <-c.exitChan:
return
default:
if c.readDeadline > 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readDeadline))
}
p, err := c.protocol.ReadPacket(c.conn)
if err != nil {
// c.callback.OnError(err, c)
return
}
c.readChan <- p
Expand Down Expand Up @@ -167,3 +171,11 @@ func (c *TCPConn) RemoteAddr() string {
func (c *TCPConn) RemoteIP() string {
return strings.Split(c.RemoteAddr(), ":")[0]
}

func (c *TCPConn) setReadDeadline(t time.Duration) error {
if !c.IsClosed() {
return errors.New("conn is running")
}
c.readDeadline = t
return nil
}
62 changes: 35 additions & 27 deletions tcp_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type TCPServer struct {
exitChan chan struct{}

maxPacketSize uint32 //single packet max bytes
deadLine time.Duration //the tcp connection read and write timeout
readDeadline time.Duration //conn read deadline
bucket *TCPConnBucket
}

Expand Down Expand Up @@ -80,12 +80,14 @@ func (srv *TCPServer) Serve(l *net.TCPListener) error {
}
srv.listener.Close()
}()
// go func() {
// for {
// srv.removeClosedTCPConn()
// time.Sleep(time.Millisecond * 10)
// }
// }()

//清理无效连接
go func() {
for {
srv.removeClosedTCPConn()
time.Sleep(time.Millisecond * 10)
}
}()

var tempDelay time.Duration
for {
Expand Down Expand Up @@ -122,7 +124,12 @@ func (srv *TCPServer) newTCPConn(conn *net.TCPConn, callback CallBack, protocol
// if the handler is nil, use srv handler
callback = srv.callback
}
return NewTCPConn(conn, callback, protocol)
c := NewTCPConn(conn, callback, protocol)
if srv.readDeadline > 0 {
log.Println(c.setReadDeadline(srv.readDeadline))
}
c.Serve()
return c
}

//Connect 使用指定的callback和protocol连接其他TCPServer,返回TCPConn
Expand Down Expand Up @@ -151,25 +158,22 @@ func (srv *TCPServer) Close() {
}
}

// func (srv *TCPServer) removeClosedTCPConn() {
// for {
// select {
// case <-srv.exitChan:
// return
// default:
// removeKey := make(map[string]struct{})
// for key, conn := range srv.bucket.GetAll() {
// if conn.IsClosed() {
// removeKey[key] = struct{}{}
// }
// }
// for key := range removeKey {
// srv.bucket.Delete(key)
// }
// time.Sleep(time.Millisecond * 10)
// }
// }
// }
func (srv *TCPServer) removeClosedTCPConn() {
select {
case <-srv.exitChan:
return
default:
removeKey := make(map[string]struct{})
for key, conn := range srv.bucket.GetAll() {
if conn.IsClosed() {
removeKey[key] = struct{}{}
}
}
for key := range removeKey {
srv.bucket.Delete(key)
}
}
}

//GetAllTCPConn 返回所有客户端连接
func (srv *TCPServer) GetAllTCPConn() []*TCPConn {
Expand All @@ -183,3 +187,7 @@ func (srv *TCPServer) GetAllTCPConn() []*TCPConn {
func (srv *TCPServer) GetTCPConn(key string) *TCPConn {
return srv.bucket.Get(key)
}

func (srv *TCPServer) SetReadDeadline(t time.Duration) {
srv.readDeadline = t
}

0 comments on commit 2174fad

Please sign in to comment.