Skip to content
This repository has been archived by the owner on May 22, 2020. It is now read-only.

Commit

Permalink
Major overhaul of Read() logic: cleaner and safer
Browse files Browse the repository at this point in the history
  • Loading branch information
sergeyfrolov committed Feb 11, 2017
1 parent ab361f4 commit 291cc08
Showing 1 changed file with 54 additions and 26 deletions.
80 changes: 54 additions & 26 deletions tapdance/tdConn.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,17 @@ func (tdConn *tapdanceConn) read_as(b []byte, caller int) (n int, err error) {
// if MSG_TYPE == DATA:
// 4-length: DATA

var readBytesTotal, totalBytesToRead uint16
var readBytes int
var headerSize, msgLen, magicVal, expectedMagicVal uint16
var readBytesTotal uint16
headerSize := uint16(3)
totalBytesToRead := headerSize

var msgLen uint16
var msgType uint8
headerSize = 3
totalBytesToRead = 3
n = 0

headerIsRead := false

// Read into special buffer, if it is connect/reconnect
var read_buffer []byte
switch caller {
case TD_RECONNECT_CALL: fallthrough
Expand All @@ -263,9 +266,39 @@ func (tdConn *tapdanceConn) read_as(b []byte, caller int) (n int, err error) {
default: panic("tdConn.read_as was called with incorrect caller " + string(caller))
}

// TODO: read only totalBytesToRead?
// This function checks if message type, given particular caller, is appropriate.
// In case it is appropriate - returns nil, otherwise - the error
checkMsgType := func(_msgType uint8, _caller int) (error) {
switch _msgType {
case MSG_RECONNECT:
if _caller == TD_USER_CALL {
return errors.New("Received RECONNECT message in initialized connection")
} else if _caller == TD_INIT_CALL {
return errors.New("Received RECONNECT message instead of INIT!")
}
case MSG_INIT:
if _caller == TD_USER_CALL {
return errors.New("Received INIT message in initialized connection")
}
if _caller == TD_RECONNECT_CALL {
// TODO: will be error eventually
Logger.Warningf("[Flow " + strconv.FormatUint(uint64(tdConn.id), 10) +
"] Got INIT instead of reconnect! Moving on")
}
case MSG_DATA:
if _caller == TD_RECONNECT_CALL || _caller == TD_INIT_CALL {
return errors.New("Received DATA message in uninitialized connection")
}
case MSG_CLOSE:
// always appropriate
default:
return errors.New("Unknown message #" + strconv.FormatUint(uint64(_msgType), 10))
}
return nil
}

for readBytesTotal < totalBytesToRead {
readBytes, err = tdConn.ztlsConn.Read(read_buffer[readBytesTotal:])
readBytes, err = tdConn.ztlsConn.Read(read_buffer[readBytesTotal:totalBytesToRead])
if caller == TD_USER_CALL && atomic.LoadInt32(&tdConn.reconnecting) != 0 {
tdConn.awaitReconnection()
} else if err != nil {
Expand All @@ -284,32 +317,30 @@ func (tdConn *tapdanceConn) read_as(b []byte, caller int) (n int, err error) {
}
}
readBytesTotal += uint16(readBytes)
if readBytesTotal >= headerSize && totalBytesToRead == headerSize {
// once we read msg_len, add it to totalBytesToRead

if readBytesTotal >= headerSize && !headerIsRead {
// Once we read the header
headerIsRead = true

// Check if the message type is appropriate
msgType = read_buffer[0]
err = checkMsgType(msgType, caller)
if err != nil {
return
}

// Add msgLen to totalBytesToRead
msgLen = binary.BigEndian.Uint16(read_buffer[1:3])
totalBytesToRead = headerSize + msgLen

}
}

// Process actual message
switch msgType {
case MSG_RECONNECT:
if caller == TD_USER_CALL {
err = errors.New("Received RECONNECT message in initialized connection")
} else if caller == TD_INIT_CALL {
err = errors.New("Received RECONNECT message instead of INIT!")
}
fallthrough
case MSG_INIT:
if caller == TD_USER_CALL {
err = errors.New("Received INIT message in initialized connection")
}
if caller == TD_RECONNECT_CALL && msgType == MSG_INIT {
// TODO: will be error eventually
Logger.Warningf("[Flow " + strconv.FormatUint(uint64(tdConn.id), 10) +
"] Got INIT instead of reconnect! Moving on")
}
var magicVal, expectedMagicVal uint16
magicVal = binary.BigEndian.Uint16(read_buffer[3:5])
expectedMagicVal = uint16(0x2a75)
if magicVal != expectedMagicVal {
Expand All @@ -320,7 +351,6 @@ func (tdConn *tapdanceConn) read_as(b []byte, caller int) (n int, err error) {
}
Logger.Infof("[Flow " + strconv.FormatUint(uint64(tdConn.id), 10) +
"] Successfully connected to Tapdance Station!")
// TODO: copy extra bytes into shared buffer
case MSG_DATA:
n = int(readBytesTotal - headerSize)
copy(b, read_buffer[headerSize:readBytesTotal])
Expand All @@ -330,8 +360,6 @@ func (tdConn *tapdanceConn) read_as(b []byte, caller int) (n int, err error) {
err = errors.New("MSG_CLOSE")
Logger.Infof("[Flow " + strconv.FormatUint(uint64(tdConn.id), 10) +
"] received MSG_CLOSE")
default:
err = errors.New("Unknown message #" + strconv.FormatUint(uint64(msgType), 10))
}
return
}
Expand Down

0 comments on commit 291cc08

Please sign in to comment.