diff --git a/connack.go b/connack.go new file mode 100644 index 0000000..0bad59a --- /dev/null +++ b/connack.go @@ -0,0 +1,152 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import "fmt" + +// The CONNACK Packet is the packet sent by the Server in response to a CONNECT Packet +// received from a Client. The first packet sent from the Server to the Client MUST +// be a CONNACK Packet [MQTT-3.2.0-1]. +// +// If the Client does not receive a CONNACK Packet from the Server within a reasonable +// amount of time, the Client SHOULD close the Network Connection. A "reasonable" amount +// of time depends on the type of application and the communications infrastructure. +type ConnackMessage struct { + header + + sessionPresent bool + returnCode ConnackCode +} + +var _ Message = (*ConnackMessage)(nil) + +// NewConnackMessage creates a new CONNACK message +func NewConnackMessage() *ConnackMessage { + msg := &ConnackMessage{} + msg.SetType(CONNACK) + + return msg +} + +// String returns a string representation of the CONNACK message +func (this ConnackMessage) String() string { + return fmt.Sprintf("%v\nSession Present: %t\nReturn code: %v\n", + this.header, this.sessionPresent, this.returnCode) +} + +// SessionPresent returns the session present flag value +func (this *ConnackMessage) SessionPresent() bool { + return this.sessionPresent +} + +// SetSessionPresent sets the value of the session present flag +func (this *ConnackMessage) SetSessionPresent(v bool) { + if v { + this.sessionPresent = true + } else { + this.sessionPresent = false + } +} + +// ReturnCode returns the return code received for the CONNECT message. The return +// type is an error +func (this *ConnackMessage) ReturnCode() ConnackCode { + return this.returnCode +} + +func (this *ConnackMessage) SetReturnCode(ret ConnackCode) { + this.returnCode = ret +} + +func (this *ConnackMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +func (this *ConnackMessage) Decode(src []byte) (int, error) { + total := 0 + + n, err := this.header.decode(src) + total += n + if err != nil { + return total, err + } + + b := src[total] + + if b&254 != 0 { + return 0, fmt.Errorf("connack/Decode: Bits 7-1 in Connack Acknowledge Flags byte (1) are not 0") + } + + this.sessionPresent = b&0x1 == 1 + total++ + + b = src[total] + + // Read return code + if b > 5 { + return 0, fmt.Errorf("connack/Decode: Invalid CONNACK return code (%d)", b) + } + + this.returnCode = ConnackCode(b) + total++ + + return total, nil +} + +func (this *ConnackMessage) Encode(dst []byte) (int, error) { + // CONNACK remaining length fixed at 2 bytes + hl := this.header.msglen() + ml := this.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return 0, err + } + + if this.sessionPresent { + dst[total] = 1 + } + total++ + + if this.returnCode > 5 { + return total, fmt.Errorf("connack/Encode: Invalid CONNACK return code (%d)", this.returnCode) + } + + dst[total] = this.returnCode.Value() + total++ + + return total, nil +} + +func (this *ConnackMessage) msglen() int { + return 2 +} diff --git a/connack_test.go b/connack_test.go new file mode 100644 index 0000000..061471b --- /dev/null +++ b/connack_test.go @@ -0,0 +1,160 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestConnackMessageFields(t *testing.T) { + msg := NewConnackMessage() + + msg.SetSessionPresent(true) + assert.True(t, true, msg.SessionPresent(), "Error setting session present flag.") + + msg.SetSessionPresent(false) + assert.False(t, true, msg.SessionPresent(), "Error setting session present flag.") + + msg.SetReturnCode(ConnectionAccepted) + assert.Equal(t, true, ConnectionAccepted, msg.ReturnCode(), "Error setting return code.") +} + +func TestConnackMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 2, + 0, // session not present + 0, // connection accepted + } + + msg := NewConnackMessage() + + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.False(t, true, msg.SessionPresent(), "Error decoding session present flag.") + assert.Equal(t, true, ConnectionAccepted, msg.ReturnCode(), "Error decoding return code.") +} + +// testing wrong message length +func TestConnackMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 3, + 0, // session not present + 0, // connection accepted + } + + msg := NewConnackMessage() + + _, err := msg.Decode(msgBytes) + assert.Error(t, true, err, "Error decoding message.") +} + +// testing wrong message size +func TestConnackMessageDecode3(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 2, + 0, // session not present + } + + msg := NewConnackMessage() + + _, err := msg.Decode(msgBytes) + assert.Error(t, true, err, "Error decoding message.") +} + +// testing wrong reserve bits +func TestConnackMessageDecode4(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 2, + 64, // <- wrong size + 0, // connection accepted + } + + msg := NewConnackMessage() + + _, err := msg.Decode(msgBytes) + assert.Error(t, true, err, "Error decoding message.") +} + +// testing invalid return code +func TestConnackMessageDecode5(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 2, + 0, + 6, // <- wrong code + } + + msg := NewConnackMessage() + + _, err := msg.Decode(msgBytes) + assert.Error(t, true, err, "Error decoding message.") +} + +func TestConnackMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 2, + 1, // session present + 0, // connection accepted + } + + msg := NewConnackMessage() + msg.SetReturnCode(ConnectionAccepted) + msg.SetSessionPresent(true) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error encoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error encoding connack message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestConnackDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(CONNACK << 4), + 2, + 0, // session not present + 0, // connection accepted + } + + msg := NewConnackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/connackcode.go b/connackcode.go new file mode 100644 index 0000000..f43c5f6 --- /dev/null +++ b/connackcode.go @@ -0,0 +1,89 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// ConnackCode is the type representing the return code in the CONNACK message, +// returned after the initial CONNECT message +type ConnackCode byte + +const ( + // Connection accepted + ConnectionAccepted ConnackCode = iota + + // The Server does not support the level of the MQTT protocol requested by the Client + ErrInvalidProtocolVersion + + // The Client identifier is correct UTF-8 but not allowed by the server + ErrIdentifierRejected + + // The Network Connection has been made but the MQTT service is unavailable + ErrServerUnavailable + + // The data in the user name or password is malformed + ErrBadUsernameOrPassword + + // The Client is not authorized to connect + ErrNotAuthorized +) + +// Value returns the value of the ConnackCode, which is just the byte representation +func (this ConnackCode) Value() byte { + return byte(this) +} + +// Desc returns the description of the ConnackCode +func (this ConnackCode) Desc() string { + switch this { + case 0: + return "Connection accepted" + case 1: + return "The Server does not support the level of the MQTT protocol requested by the Client" + case 2: + return "The Client identifier is correct UTF-8 but not allowed by the server" + case 3: + return "The Network Connection has been made but the MQTT service is unavailable" + case 4: + return "The data in the user name or password is malformed" + case 5: + return "The Client is not authorized to connect" + } + + return "" +} + +// Valid checks to see if the ConnackCode is valid. Currently valid codes are <= 5 +func (this ConnackCode) Valid() bool { + return this <= 5 +} + +// Error returns the corresonding error string for the ConnackCode +func (this ConnackCode) Error() string { + switch this { + case 0: + return "Connection accepted" + case 1: + return "Connection Refused, unacceptable protocol version" + case 2: + return "Connection Refused, identifier rejected" + case 3: + return "Connection Refused, Server unavailable" + case 4: + return "Connection Refused, bad user name or password" + case 5: + return "Connection Refused, not authorized" + } + + return "Unknown error" +} diff --git a/connect.go b/connect.go new file mode 100644 index 0000000..3423366 --- /dev/null +++ b/connect.go @@ -0,0 +1,573 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "encoding/binary" + "fmt" +) + +// After a Network Connection is established by a Client to a Server, the first Packet +// sent from the Client to the Server MUST be a CONNECT Packet [MQTT-3.1.0-1]. +// +// A Client can only send the CONNECT Packet once over a Network Connection. The Server +// MUST process a second CONNECT Packet sent from a Client as a protocol violation and +// disconnect the Client [MQTT-3.1.0-2]. See section 4.8 for information about +// handling errors. +type ConnectMessage struct { + header + + // 7: username flag + // 6: password flag + // 5: will retain + // 4-3: will QoS + // 2: will flag + // 1: clean session + // 0: reserved + connectFlags byte + + version byte + + keepAlive uint16 + + protoName, + clientId, + willTopic, + willMessage, + username, + password []byte +} + +var _ Message = (*ConnectMessage)(nil) + +// NewConnectMessage creates a new CONNECT message. +func NewConnectMessage() *ConnectMessage { + msg := &ConnectMessage{} + msg.SetType(CONNECT) + + return msg +} + +// String returns a string representation of the CONNECT message +func (this ConnectMessage) String() string { + return fmt.Sprintf("%v\nConnect Flags: %08b\nVersion: %d\nKeepAlive: %d\nClient ID: %s\nWill Topic: %s\nWill Message: %s\nUsername: %s\nPassword: %s\n", + this.header, + this.connectFlags, + this.Version(), + this.KeepAlive(), + this.ClientId(), + this.WillTopic(), + this.WillMessage(), + this.Username(), + this.Password(), + ) +} + +// Version returns the the 8 bit unsigned value that represents the revision level +// of the protocol used by the Client. The value of the Protocol Level field for +// the version 3.1.1 of the protocol is 4 (0x04). +func (this *ConnectMessage) Version() byte { + return this.version +} + +// SetVersion sets the version value of the CONNECT message +func (this *ConnectMessage) SetVersion(v byte) error { + if _, ok := SupportedVersions[v]; !ok { + return fmt.Errorf("connect/SetVersion: Invalid version number %d", v) + } + + this.version = v + return nil +} + +// CleanSession returns the bit that specifies the handling of the Session state. +// The Client and Server can store Session state to enable reliable messaging to +// continue across a sequence of Network Connections. This bit is used to control +// the lifetime of the Session state. +func (this *ConnectMessage) CleanSession() bool { + return ((this.connectFlags >> 1) & 0x1) == 1 +} + +// SetCleanSession sets the bit that specifies the handling of the Session state. +func (this *ConnectMessage) SetCleanSession(v bool) { + if v { + this.connectFlags |= 0x2 // 00000010 + } else { + this.connectFlags &= 253 // 11111101 + } +} + +// WillFlag returns the bit that specifies whether a Will Message should be stored +// on the server. If the Will Flag is set to 1 this indicates that, if the Connect +// request is accepted, a Will Message MUST be stored on the Server and associated +// with the Network Connection. +func (this *ConnectMessage) WillFlag() bool { + return ((this.connectFlags >> 2) & 0x1) == 1 +} + +// SetWillFlag sets the bit that specifies whether a Will Message should be stored +// on the server. +func (this *ConnectMessage) SetWillFlag(v bool) { + if v { + this.connectFlags |= 0x4 // 00000100 + } else { + this.connectFlags &= 251 // 11111011 + } +} + +// WillQos returns the two bits that specify the QoS level to be used when publishing +// the Will Message. +func (this *ConnectMessage) WillQos() byte { + return (this.connectFlags >> 3) & 0x3 +} + +// SetWillQos sets the two bits that specify the QoS level to be used when publishing +// the Will Message. +func (this *ConnectMessage) SetWillQos(qos byte) error { + if qos != QosAtMostOnce && qos != QosAtLeastOnce && qos != QosExactlyOnce { + return fmt.Errorf("connect/SetWillQos: Invalid QoS level %d", qos) + } + + this.connectFlags = (this.connectFlags & 231) | (qos << 3) // 231 = 11100111 + return nil +} + +// WillRetain returns the bit specifies if the Will Message is to be Retained when it +// is published. +func (this *ConnectMessage) WillRetain() bool { + return ((this.connectFlags >> 5) & 0x1) == 1 +} + +// SetWillRetain sets the bit specifies if the Will Message is to be Retained when it +// is published. +func (this *ConnectMessage) SetWillRetain(v bool) { + if v { + this.connectFlags |= 32 // 00100000 + } else { + this.connectFlags &= 223 // 11011111 + } +} + +// UsernameFlag returns the bit that specifies whether a user name is present in the +// payload. +func (this *ConnectMessage) UsernameFlag() bool { + return ((this.connectFlags >> 7) & 0x1) == 1 +} + +// SetUsernameFlag sets the bit that specifies whether a user name is present in the +// payload. +func (this *ConnectMessage) SetUsernameFlag(v bool) { + if v { + this.connectFlags |= 128 // 10000000 + } else { + this.connectFlags &= 127 // 01111111 + } +} + +// PasswordFlag returns the bit that specifies whether a password is present in the +// payload. +func (this *ConnectMessage) PasswordFlag() bool { + return ((this.connectFlags >> 6) & 0x1) == 1 +} + +// SetPasswordFlag sets the bit that specifies whether a password is present in the +// payload. +func (this *ConnectMessage) SetPasswordFlag(v bool) { + if v { + this.connectFlags |= 64 // 01000000 + } else { + this.connectFlags &= 191 // 10111111 + } +} + +// KeepAlive returns a time interval measured in seconds. Expressed as a 16-bit word, +// it is the maximum time interval that is permitted to elapse between the point at +// which the Client finishes transmitting one Control Packet and the point it starts +// sending the next. +func (this *ConnectMessage) KeepAlive() uint16 { + return this.keepAlive +} + +// SetKeepAlive sets the time interval in which the server should keep the connection +// alive. +func (this *ConnectMessage) SetKeepAlive(v uint16) { + this.keepAlive = v +} + +// ClientId returns an ID that identifies the Client to the Server. Each Client +// connecting to the Server has a unique ClientId. The ClientId MUST be used by +// Clients and by Servers to identify state that they hold relating to this MQTT +// Session between the Client and the Server +func (this *ConnectMessage) ClientId() []byte { + return this.clientId +} + +// SetClientId sets an ID that identifies the Client to the Server. +func (this *ConnectMessage) SetClientId(v []byte) error { + if len(v) > 0 && !ValidClientId(v) { + return ErrIdentifierRejected + } + + this.clientId = v + return nil +} + +// WillTopic returns the topic in which the Will Message should be published to. +// If the Will Flag is set to 1, the Will Topic must be in the payload. +func (this *ConnectMessage) WillTopic() []byte { + return this.willTopic +} + +// SetWillTopic sets the topic in which the Will Message should be published to. +func (this *ConnectMessage) SetWillTopic(v []byte) { + this.willTopic = v + + if len(v) > 0 { + this.SetWillFlag(true) + } else if len(this.willMessage) == 0 { + this.SetWillFlag(false) + } +} + +// WillMessage returns the Will Message that is to be published to the Will Topic. +func (this *ConnectMessage) WillMessage() []byte { + return this.willMessage +} + +// SetWillMessage sets the Will Message that is to be published to the Will Topic. +func (this *ConnectMessage) SetWillMessage(v []byte) { + this.willMessage = v + + if len(v) > 0 { + this.SetWillFlag(true) + } else if len(this.willTopic) == 0 { + this.SetWillFlag(false) + } +} + +// Username returns the username from the payload. If the User Name Flag is set to 1, +// this must be in the payload. It can be used by the Server for authentication and +// authorization. +func (this *ConnectMessage) Username() []byte { + return this.username +} + +// SetUsername sets the username for authentication. +func (this *ConnectMessage) SetUsername(v []byte) { + this.username = v + + if len(v) > 0 { + this.SetUsernameFlag(true) + } else { + this.SetUsernameFlag(false) + } +} + +// Password returns the password from the payload. If the Password Flag is set to 1, +// this must be in the payload. It can be used by the Server for authentication and +// authorization. +func (this *ConnectMessage) Password() []byte { + return this.password +} + +// SetPassword sets the username for authentication. +func (this *ConnectMessage) SetPassword(v []byte) { + this.password = v + + if len(v) > 0 { + this.SetPasswordFlag(true) + } else { + this.SetPasswordFlag(false) + } +} + +func (this *ConnectMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +// For the CONNECT message, the error returned could be a ConnackReturnCode, so +// be sure to check that. Otherwise it's a generic error. If a generic error is +// returned, this Message should be considered invalid. +// +// Caller should call ValidConnackError(err) to see if the returned error is +// a Connack error. If so, caller should send the Client back the corresponding +// CONNACK message. +func (this *ConnectMessage) Decode(src []byte) (int, error) { + total := 0 + + n, err := this.header.decode(src[total:]) + if err != nil { + return total + n, err + } + total += n + + if n, err = this.decodeMessage(src[total:]); err != nil { + return total + n, err + } + total += n + + return total, nil +} + +func (this *ConnectMessage) Encode(dst []byte) (int, error) { + if this.Type() != CONNECT { + return 0, fmt.Errorf("connect/Encode: Invalid message type. Expecting %d, got %d", CONNECT, this.Type()) + } + + _, ok := SupportedVersions[this.version] + if !ok { + return 0, ErrInvalidProtocolVersion + } + + hl := this.header.msglen() + ml := this.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return total, err + } + + n, err = this.encodeMessage(dst[total:]) + total += n + if err != nil { + return total, err + } + + return total, nil +} + +func (this *ConnectMessage) encodeMessage(dst []byte) (int, error) { + total := 0 + + n, err := writeLPBytes(dst[total:], []byte(SupportedVersions[this.version])) + total += n + if err != nil { + return total, err + } + + dst[total] = this.version + total += 1 + + dst[total] = this.connectFlags + total += 1 + + binary.BigEndian.PutUint16(dst[total:], this.keepAlive) + total += 2 + + n, err = writeLPBytes(dst[total:], this.clientId) + total += n + if err != nil { + return total, err + } + + if this.WillFlag() { + n, err = writeLPBytes(dst[total:], this.willTopic) + total += n + if err != nil { + return total, err + } + + n, err = writeLPBytes(dst[total:], this.willMessage) + total += n + if err != nil { + return total, err + } + } + + // According to the 3.1 spec, it's possible that the usernameFlag is set, + // but the username string is missing. + if this.UsernameFlag() && len(this.username) > 0 { + n, err = writeLPBytes(dst[total:], this.username) + total += n + if err != nil { + return total, err + } + } + + // According to the 3.1 spec, it's possible that the passwordFlag is set, + // but the password string is missing. + if this.PasswordFlag() && len(this.password) > 0 { + n, err = writeLPBytes(dst[total:], this.password) + total += n + if err != nil { + return total, err + } + } + + return total, nil +} + +func (this *ConnectMessage) decodeMessage(src []byte) (int, error) { + var err error + n, total := 0, 0 + + this.protoName, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + + this.version = src[total] + total++ + + if verstr, ok := SupportedVersions[this.version]; !ok { + return total, ErrInvalidProtocolVersion + } else if verstr != string(this.protoName) { + return total, ErrInvalidProtocolVersion + } + + this.connectFlags = src[total] + total++ + + if this.connectFlags&0x1 != 0 { + return total, fmt.Errorf("connect/decodeMessage: Connect Flags reserved bit 0 is not 0") + } + + if this.WillQos() > QosExactlyOnce { + return total, fmt.Errorf("connect/decodeMessage: Invalid QoS level (%d) for %s message", this.WillQos(), this.Name()) + } + + if !this.WillFlag() && (this.WillRetain() || this.WillQos() != QosAtMostOnce) { + return total, fmt.Errorf("connect/decodeMessage: Protocol violation: If the Will Flag (%t) is set to 0 the Will QoS (%d) and Will Retain (%t) fields MUST be set to zero", this.WillFlag(), this.WillQos(), this.WillRetain()) + } + + if this.UsernameFlag() && !this.PasswordFlag() { + return total, fmt.Errorf("connect/decodeMessage: Username flag is set but Password flag is not set") + } + + if len(src[total:]) < 2 { + return 0, fmt.Errorf("connect/decodeMessage: Insufficient buffer size. Expecting %d, got %d.", 2, len(src[total:])) + } + + this.keepAlive = binary.BigEndian.Uint16(src[total:]) + total += 2 + + this.clientId, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + + // If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1 + if len(this.clientId) == 0 && !this.CleanSession() { + return total, ErrIdentifierRejected + } + + // The ClientId must contain only characters 0-9, a-z, and A-Z + // We also support ClientId longer than 23 encoded bytes + // We do not support ClientId outside of the above characters + if len(this.clientId) > 0 && !ValidClientId(this.clientId) { + return total, ErrIdentifierRejected + } + + if this.WillFlag() { + this.willTopic, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + + this.willMessage, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + } + + // According to the 3.1 spec, it's possible that the passwordFlag is set, + // but the password string is missing. + if this.UsernameFlag() && len(src[total:]) > 0 { + this.username, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + } + + // According to the 3.1 spec, it's possible that the passwordFlag is set, + // but the password string is missing. + if this.PasswordFlag() && len(src[total:]) > 0 { + this.password, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + } + + /* + if len(src[total:]) > 0 { + return total, fmt.Errorf("connect/decodeMessage: Invalid buffer size. Still has %d bytes at the end.", len(src[total:])) + } + */ + + return total, nil +} + +func (this *ConnectMessage) msglen() int { + total := 0 + + verstr, ok := SupportedVersions[this.version] + if !ok { + return total + } + + // 2 bytes protocol name length + // n bytes protocol name + // 1 byte protocol version + // 1 byte connect flags + // 2 bytes keep alive timer + total += 2 + len(verstr) + 1 + 1 + 2 + + // Add the clientID length, 2 is the length prefix + total += 2 + len(this.clientId) + + // Add the will topic and will message length, and the length prefixes + if this.WillFlag() { + total += 2 + len(this.willTopic) + 2 + len(this.willMessage) + } + + // Add the username length + // According to the 3.1 spec, it's possible that the usernameFlag is set, + // but the user name string is missing. + if this.UsernameFlag() && len(this.username) > 0 { + total += 2 + len(this.username) + } + + // Add the password length + // According to the 3.1 spec, it's possible that the passwordFlag is set, + // but the password string is missing. + if this.PasswordFlag() && len(this.password) > 0 { + total += 2 + len(this.password) + } + + return total +} diff --git a/connect_test.go b/connect_test.go new file mode 100644 index 0000000..257fe27 --- /dev/null +++ b/connect_test.go @@ -0,0 +1,366 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestConnectMessageFields(t *testing.T) { + msg := NewConnectMessage() + + err := msg.SetVersion(0x3) + assert.NoError(t, false, err, "Error setting message version.") + + assert.Equal(t, false, 0x3, msg.Version(), "Incorrect version number") + + err = msg.SetVersion(0x5) + assert.Error(t, false, err) + + msg.SetCleanSession(true) + assert.True(t, false, msg.CleanSession(), "Error setting clean session flag.") + + msg.SetCleanSession(false) + assert.False(t, false, msg.CleanSession(), "Error setting clean session flag.") + + msg.SetWillFlag(true) + assert.True(t, false, msg.WillFlag(), "Error setting will flag.") + + msg.SetWillFlag(false) + assert.False(t, false, msg.WillFlag(), "Error setting will flag.") + + msg.SetWillRetain(true) + assert.True(t, false, msg.WillRetain(), "Error setting will retain.") + + msg.SetWillRetain(false) + assert.False(t, false, msg.WillRetain(), "Error setting will retain.") + + msg.SetPasswordFlag(true) + assert.True(t, false, msg.PasswordFlag(), "Error setting password flag.") + + msg.SetPasswordFlag(false) + assert.False(t, false, msg.PasswordFlag(), "Error setting password flag.") + + msg.SetUsernameFlag(true) + assert.True(t, false, msg.UsernameFlag(), "Error setting username flag.") + + msg.SetUsernameFlag(false) + assert.False(t, false, msg.UsernameFlag(), "Error setting username flag.") + + msg.SetWillQos(1) + assert.Equal(t, false, 1, msg.WillQos(), "Error setting will QoS.") + + err = msg.SetWillQos(4) + assert.Error(t, false, err) + + err = msg.SetClientId([]byte("j0j0jfajf02j0asdjf")) + assert.NoError(t, false, err, "Error setting client ID") + + assert.Equal(t, false, "j0j0jfajf02j0asdjf", string(msg.ClientId()), "Error setting client ID.") + + err = msg.SetClientId([]byte("this is no good")) + assert.Error(t, false, err) + + msg.SetWillTopic([]byte("willtopic")) + assert.Equal(t, false, "willtopic", string(msg.WillTopic()), "Error setting will topic.") + + assert.True(t, false, msg.WillFlag(), "Error setting will flag.") + + msg.SetWillTopic([]byte("")) + assert.Equal(t, false, "", string(msg.WillTopic()), "Error setting will topic.") + + assert.False(t, false, msg.WillFlag(), "Error setting will flag.") + + msg.SetWillMessage([]byte("this is a will message")) + assert.Equal(t, false, "this is a will message", string(msg.WillMessage()), "Error setting will message.") + + assert.True(t, false, msg.WillFlag(), "Error setting will flag.") + + msg.SetWillMessage([]byte("")) + assert.Equal(t, false, "", string(msg.WillMessage()), "Error setting will topic.") + + assert.False(t, false, msg.WillFlag(), "Error setting will flag.") + + msg.SetWillTopic([]byte("willtopic")) + msg.SetWillMessage([]byte("this is a will message")) + msg.SetWillTopic([]byte("")) + assert.True(t, false, msg.WillFlag(), "Error setting will topic.") + + msg.SetUsername([]byte("myname")) + assert.Equal(t, false, "myname", string(msg.Username()), "Error setting will message.") + + assert.True(t, false, msg.UsernameFlag(), "Error setting will flag.") + + msg.SetUsername([]byte("")) + assert.Equal(t, false, "", string(msg.Username()), "Error setting will message.") + + assert.False(t, false, msg.UsernameFlag(), "Error setting will flag.") + + msg.SetPassword([]byte("myname")) + assert.Equal(t, false, "myname", string(msg.Password()), "Error setting will message.") + + assert.True(t, false, msg.PasswordFlag(), "Error setting will flag.") + + msg.SetPassword([]byte("")) + assert.Equal(t, false, "", string(msg.Password()), "Error setting will message.") + + assert.False(t, false, msg.PasswordFlag(), "Error setting will flag.") +} + +func TestConnectMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(CONNECT << 4), + 60, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 206, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 7, // Client ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', + } + + msg := NewConnectMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, 206, msg.connectFlags, "Incorrect flag value.") + assert.Equal(t, true, 10, msg.KeepAlive(), "Incorrect KeepAlive value.") + assert.Equal(t, true, "surgemq", string(msg.ClientId()), "Incorrect client ID value.") + assert.Equal(t, true, "will", string(msg.WillTopic()), "Incorrect will topic value.") + assert.Equal(t, true, "send me home", string(msg.WillMessage()), "Incorrect will message value.") + assert.Equal(t, true, "surgemq", string(msg.Username()), "Incorrect username value.") + assert.Equal(t, true, "verysecret", string(msg.Password()), "Incorrect password value.") +} + +func TestConnectMessageDecode2(t *testing.T) { + // missing last byte 't' + msgBytes := []byte{ + byte(CONNECT << 4), + 60, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 206, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 7, // Client ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', + } + + msg := NewConnectMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestConnectMessageDecode3(t *testing.T) { + // extra bytes + msgBytes := []byte{ + byte(CONNECT << 4), + 60, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 206, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 7, // Client ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', + 'e', 'x', 't', 'r', 'a', + } + + msg := NewConnectMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err) + assert.Equal(t, true, 62, n) +} + +func TestConnectMessageDecode4(t *testing.T) { + // missing client Id, clean session == 0 + msgBytes := []byte{ + byte(CONNECT << 4), + 53, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 204, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 0, // Client ID LSB (0) + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', + } + + msg := NewConnectMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestConnectMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(CONNECT << 4), + 60, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 206, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 7, // Client ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', + } + + msg := NewConnectMessage() + msg.SetWillQos(1) + msg.SetVersion(4) + msg.SetCleanSession(true) + msg.SetClientId([]byte("surgemq")) + msg.SetKeepAlive(10) + msg.SetWillTopic([]byte("will")) + msg.SetWillMessage([]byte("send me home")) + msg.SetUsername([]byte("surgemq")) + msg.SetPassword([]byte("verysecret")) + + dst := make([]byte, 100) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestConnectDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(CONNECT << 4), + 60, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 206, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 7, // Client ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', + } + + msg := NewConnectMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/disconnect.go b/disconnect.go new file mode 100644 index 0000000..4d7af75 --- /dev/null +++ b/disconnect.go @@ -0,0 +1,39 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// The DISCONNECT Packet is the final Control Packet sent from the Client to the Server. +// It indicates that the Client is disconnecting cleanly. +type DisconnectMessage struct { + header +} + +var _ Message = (*DisconnectMessage)(nil) + +// NewDisconnectMessage creates a new DISCONNECT message. +func NewDisconnectMessage() *DisconnectMessage { + msg := &DisconnectMessage{} + msg.SetType(DISCONNECT) + + return msg +} + +func (this *DisconnectMessage) Decode(src []byte) (int, error) { + return this.header.decode(src) +} + +func (this *DisconnectMessage) Encode(dst []byte) (int, error) { + return this.header.encode(dst) +} diff --git a/disconnect_test.go b/disconnect_test.go new file mode 100644 index 0000000..eeb68dc --- /dev/null +++ b/disconnect_test.go @@ -0,0 +1,78 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestDisconnectMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(DISCONNECT << 4), + 0, + } + + msg := NewDisconnectMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, DISCONNECT, msg.Type(), "Error decoding message.") +} + +func TestDisconnectMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(DISCONNECT << 4), + 0, + } + + msg := NewDisconnectMessage() + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestDisconnectDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(DISCONNECT << 4), + 0, + } + + msg := NewDisconnectMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..5f18fd8 --- /dev/null +++ b/doc.go @@ -0,0 +1,141 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Package mqtt is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can +find the MQTT specs at the following locations: + + 3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/ + 3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html + +From the spec: + + MQTT is a Client Server publish/subscribe messaging transport protocol. It is + light weight, open, simple, and designed so as to be easy to implement. These + characteristics make it ideal for use in many situations, including constrained + environments such as for communication in Machine to Machine (M2M) and Internet + of Things (IoT) contexts where a small code footprint is required and/or network + bandwidth is at a premium. + + The MQTT protocol works by exchanging a series of MQTT messages in a defined way. + The protocol runs over TCP/IP, or over other network protocols that provide + ordered, lossless, bi-directional connections. + + +There are two main items to take note in this package. The first is + + type MessageType byte + +MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT +control packet type is represented as a 4-bit unsigned value. MessageType receives +several methods that returns string representations of the names and descriptions. + +Also, one of the methods is New(). It returns a new Message object based on the mtype +parameter. For example: + + m, err := CONNECT.New() + msg := m.(*ConnectMessage) + +This would return a PublishMessage struct, but mapped to the Message interface. You can +then type assert it back to a *PublishMessage. Another way to create a new +PublishMessage is to call + + msg := NewConnectMessage() + +Every message type has a New function that returns a new message. The list of available +message types are defined as constants below. + +As you may have noticed, the second important item is the Message interface. It defines +several methods that are common to all messages, including Name(), Desc(), and Type(). +Most importantly, it also defines the Encode() and Decode() methods. + + Encode() (io.Reader, int, error) + Decode(io.Reader) (int, error) + +Encode returns an io.Reader in which the encoded bytes can be read. The second return +value is the number of bytes encoded, so the caller knows how many bytes there will be. +If Encode returns an error, then the first two return values should be considered invalid. +Any changes to the message after Encode() is called will invalidate the io.Reader. + +Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader +returns EOF or error. The first return value is the number of bytes read from io.Reader. +The second is error if Decode encounters any problems. + +With these in mind, we can now do: + + // Create a new CONNECT message + msg := NewConnectMessage() + + // Set the appropriate parameters + msg.SetWillQos(1) + msg.SetVersion(4) + msg.SetCleanSession(true) + msg.SetClientId([]byte("surgemq")) + msg.SetKeepAlive(10) + msg.SetWillTopic([]byte("will")) + msg.SetWillMessage([]byte("send me home")) + msg.SetUsername([]byte("surgemq")) + msg.SetPassword([]byte("verysecret")) + + // Encode the message and get the io.Reader + r, n, err := msg.Encode() + if err == nil { + return err + } + + // Write n bytes into the connection + m, err := io.CopyN(conn, r, int64(n)) + if err != nil { + return err + } + + fmt.Printf("Sent %d bytes of %s message", m, msg.Name()) + +To receive a CONNECT message from a connection, we can do: + + // Create a new CONNECT message + msg := NewConnectMessage() + + // Decode the message by reading from conn + n, err := msg.Decode(conn) + +If you don't know what type of message is coming down the pipe, you can do something like this: + + // Create a buffered IO reader for the connection + br := bufio.NewReader(conn) + + // Peek at the first byte, which contains the message type + b, err := br.Peek(1) + if err != nil { + return err + } + + // Extract the type from the first byte + t := MessageType(b[0] >> 4) + + // Create a new message + msg, err := t.New() + if err != nil { + return err + } + + // Decode it from the bufio.Reader + n, err := msg.Decode(br) + if err != nil { + return err + } + + +*/ +package message diff --git a/header.go b/header.go new file mode 100644 index 0000000..fb72b7f --- /dev/null +++ b/header.go @@ -0,0 +1,194 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "encoding/binary" + "fmt" +) + +var ( + gPacketId uint64 = 0 +) + +// Fixed header +// - 1 byte for control packet type (bits 7-4) and flags (bits 3-0) +// - up to 4 byte for remaining length +type header struct { + remlen int32 + mtype MessageType + flags byte + packetId uint16 +} + +// String returns a string representation of the message. +func (this header) String() string { + return fmt.Sprintf("Packet type: %s\nFlags: %08b\nRemaining Length: %d bytes\n", this.mtype.Name(), this.flags, this.remlen) +} + +// Name returns a string representation of the message type. Examples include +// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of +// the message types and cannot be changed. +func (this *header) Name() string { + return this.Type().Name() +} + +// Desc returns a string description of the message type. For example, a +// CONNECT message would return "Client request to connect to Server." These +// descriptions are statically defined (copied from the MQTT spec) and cannot +// be changed. +func (this *header) Desc() string { + return this.Type().Desc() +} + +// Type returns the MessageType of the Message. The retured value should be one +// of the constants defined for MessageType. +func (this *header) Type() MessageType { + return this.mtype +} + +// SetType sets the message type of this message. It also correctly sets the +// default flags for the message type. It returns an error if the type is invalid. +func (this *header) SetType(mtype MessageType) error { + if !mtype.Valid() { + return fmt.Errorf("header/SetType: Invalid control packet type %d", mtype) + } + + this.mtype = mtype + + this.flags = mtype.DefaultFlags() + + return nil +} + +// Flags returns the fixed header flags for this message. +func (this *header) Flags() byte { + return this.flags +} + +// RemainingLength returns the length of the non-fixed-header part of the message. +func (this *header) RemainingLength() int32 { + return this.remlen +} + +// SetRemainingLength sets the length of the non-fixed-header part of the message. +// It returns error if the length is greater than 268435455, which is the max +// message length as defined by the MQTT spec. +func (this *header) SetRemainingLength(remlen int32) error { + if remlen > maxRemainingLength || remlen < 0 { + return fmt.Errorf("header/SetLength: Remaining length (%d) out of bound (max %d, min 0)", remlen, maxRemainingLength) + } + + this.remlen = remlen + return nil +} + +func (this *header) Len() int { + return this.msglen() +} + +// PacketId returns the ID of the packet. +func (this *header) PacketId() uint16 { + return this.packetId +} + +// SetPacketId sets the ID of the packet. +func (this *header) SetPacketId(v uint16) { + this.packetId = v +} + +func (this *header) encode(dst []byte) (int, error) { + ml := this.msglen() + + if len(dst) < ml { + return 0, fmt.Errorf("header/Encode: Insufficient buffer size. Expecting %d, got %d.", ml, len(dst)) + } + + total := 0 + + if this.remlen > maxRemainingLength || this.remlen < 0 { + return total, fmt.Errorf("header/Encode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength) + } + + if !this.mtype.Valid() { + return total, fmt.Errorf("header/Encode: Invalid message type %d", this.mtype) + } + + dst[total] = byte(this.mtype)<<4 | this.flags + total += 1 + + n := binary.PutUvarint(dst[total:], uint64(this.remlen)) + total += n + + return total, nil +} + +// Decode reads from the io.Reader parameter until a full message is decoded, or +// when io.Reader returns EOF or error. The first return value is the number of +// bytes read from io.Reader. The second is error if Decode encounters any problems. +func (this *header) decode(src []byte) (int, error) { + total := 0 + + mtype := MessageType(src[total] >> 4) + if !mtype.Valid() { + return total, fmt.Errorf("header/Decode: Invalid message type %d.", mtype) + } + + if mtype != this.mtype { + return total, fmt.Errorf("header/Decode: Invalid message type %d. Expecting %d.", mtype, this.mtype) + } + + this.flags = src[total] & 0x0f + if this.mtype != PUBLISH && this.flags != this.mtype.DefaultFlags() { + return total, fmt.Errorf("header/Decode: Invalid message (%d) flags. Expecting %d, got %d", this.mtype, this.mtype.DefaultFlags, this.flags) + } + + if this.mtype == PUBLISH && !ValidQos((this.flags>>1)&0x3) { + return total, fmt.Errorf("header/Decode: Invalid QoS (%d) for PUBLISH message.", (this.flags>>1)&0x3) + } + + total++ + + remlen, m := binary.Uvarint(src[total:]) + total += m + this.remlen = int32(remlen) + + if this.remlen > maxRemainingLength || remlen < 0 { + return total, fmt.Errorf("header/Decode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength) + } + + if int(this.remlen) > len(src[total:]) { + return total, fmt.Errorf("header/Decode: Remaining length (%d) is greater than remaining buffer (%d)", this.remlen, len(src[total:])) + } + + return total, nil +} + +func (this *header) msglen() int { + // message type and flag byte + total := 1 + + if this.remlen <= 127 { + total += 1 + } else if this.remlen <= 16383 { + total += 2 + } else if this.remlen <= 2097151 { + total += 3 + } else { + total += 4 + } + + return total +} diff --git a/header_test.go b/header_test.go new file mode 100644 index 0000000..3f99d3d --- /dev/null +++ b/header_test.go @@ -0,0 +1,178 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestMessageHeaderFields(t *testing.T) { + header := &header{} + + header.SetRemainingLength(33) + + assert.Equal(t, true, 33, header.RemainingLength()) + + err := header.SetRemainingLength(268435456) + + assert.Error(t, true, err) + + err = header.SetRemainingLength(-1) + + assert.Error(t, true, err) + + err = header.SetType(RESERVED) + + assert.Error(t, true, err) + + err = header.SetType(PUBREL) + + assert.NoError(t, true, err) + assert.Equal(t, true, PUBREL, header.Type()) + assert.Equal(t, true, "PUBREL", header.Name()) + assert.Equal(t, true, 2, header.Flags()) +} + +// Not enough bytes +func TestMessageHeaderDecode(t *testing.T) { + buf := []byte{0x6f, 193, 2} + header := &header{} + + _, err := header.decode(buf) + assert.Error(t, true, err) +} + +// Remaining length too big +func TestMessageHeaderDecode2(t *testing.T) { + buf := []byte{0x62, 0xff, 0xff, 0xff, 0xff} + header := &header{} + + _, err := header.decode(buf) + assert.Error(t, true, err) +} + +func TestMessageHeaderDecode3(t *testing.T) { + buf := []byte{0x62, 0xff} + header := &header{} + + _, err := header.decode(buf) + assert.Error(t, true, err) +} + +func TestMessageHeaderDecode4(t *testing.T) { + buf := []byte{0x62, 0xff, 0xff, 0xff, 0x7f} + header := &header{ + mtype: 6, + flags: 2, + } + + n, err := header.decode(buf) + + assert.Error(t, true, err) + assert.Equal(t, true, 5, n) + assert.Equal(t, true, maxRemainingLength, header.RemainingLength()) +} + +func TestMessageHeaderDecode5(t *testing.T) { + buf := []byte{0x62, 0xff, 0x7f} + header := &header{ + mtype: 6, + flags: 2, + } + + n, err := header.decode(buf) + assert.Error(t, true, err) + assert.Equal(t, true, 3, n) +} + +func TestMessageHeaderEncode1(t *testing.T) { + header := &header{} + headerBytes := []byte{0x62, 193, 2} + + err := header.SetType(PUBREL) + + assert.NoError(t, true, err) + + err = header.SetRemainingLength(321) + + assert.NoError(t, true, err) + + buf := make([]byte, 3) + n, err := header.encode(buf) + + assert.NoError(t, true, err) + assert.Equal(t, true, 3, n) + assert.Equal(t, true, headerBytes, buf) +} + +func TestMessageHeaderEncode2(t *testing.T) { + header := &header{} + + err := header.SetType(PUBREL) + assert.NoError(t, true, err) + + header.remlen = 268435456 + + buf := make([]byte, 5) + _, err = header.encode(buf) + + assert.Error(t, true, err) +} + +func TestMessageHeaderEncode3(t *testing.T) { + header := &header{} + headerBytes := []byte{0x62, 0xff, 0xff, 0xff, 0x7f} + + err := header.SetType(PUBREL) + + assert.NoError(t, true, err) + + err = header.SetRemainingLength(maxRemainingLength) + + assert.NoError(t, true, err) + + buf := make([]byte, 5) + n, err := header.encode(buf) + + assert.NoError(t, true, err) + assert.Equal(t, true, 5, n) + assert.Equal(t, true, headerBytes, buf) +} + +func TestMessageHeaderEncode4(t *testing.T) { + header := &header{} + + header.mtype = RESERVED2 + + buf := make([]byte, 5) + _, err := header.encode(buf) + assert.Error(t, true, err) +} + +/* +// This test is to ensure that an empty message is at least 2 bytes long +func TestMessageHeaderEncode5(t *testing.T) { + msg := NewPingreqMessage() + + dst, n, err := msg.encode() + if err != nil { + t.Errorf("Error encoding PINGREQ message: %v", err) + } else if n != 2 { + t.Errorf("Incorrect result. Expecting length of 2 bytes, got %d.", dst.(*bytes.Buffer).Len()) + } +} +*/ diff --git a/message.go b/message.go new file mode 100644 index 0000000..5722180 --- /dev/null +++ b/message.go @@ -0,0 +1,357 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "bytes" + "fmt" + "regexp" +) + +var clientIdRegexp *regexp.Regexp + +func init() { + clientIdRegexp, _ = regexp.Compile("^[0-9a-zA-Z]*$") +} + +const ( + maxLPString uint16 = 65535 + maxFixedHeaderLength int = 5 + maxRemainingLength int32 = 268435455 // bytes, or 256 MB +) + +const ( + // QoS 0: At most once delivery + // The message is delivered according to the capabilities of the underlying network. + // No response is sent by the receiver and no retry is performed by the sender. The + // message arrives at the receiver either once or not at all. + QosAtMostOnce byte = iota + + // QoS 1: At least once delivery + // This quality of service ensures that the message arrives at the receiver at least once. + // A QoS 1 PUBLISH Packet has a Packet Identifier in its variable header and is acknowledged + // by a PUBACK Packet. Section 2.3.1 provides more information about Packet Identifiers. + QosAtLeastOnce + + // QoS 2: Exactly once delivery + // This is the highest quality of service, for use when neither loss nor duplication of + // messages are acceptable. There is an increased overhead associated with this quality of + // service. + QosExactlyOnce + + // QosFailure is a return value for a subscription if there's a problem while subscribing + // to a specific topic. + QosFailure = 0x80 +) + +// SupportedVersions is a map of the version number (0x3 or 0x4) to the version string, +// "MQIsdp" for 0x3, and "MQTT" for 0x4. +var SupportedVersions map[byte]string = map[byte]string{ + 0x3: "MQIsdp", + 0x4: "MQTT", +} + +// MessageType is the type representing the MQTT packet types. In the MQTT spec, +// MQTT control packet type is represented as a 4-bit unsigned value. +type MessageType byte + +// Message is an interface defined for all MQTT message types. +type Message interface { + // Name returns a string representation of the message type. Examples include + // "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of + // the message types and cannot be changed. + Name() string + + // Desc returns a string description of the message type. For example, a + // CONNECT message would return "Client request to connect to Server." These + // descriptions are statically defined (copied from the MQTT spec) and cannot + // be changed. + Desc() string + + // Type returns the MessageType of the Message. The retured value should be one + // of the constants defined for MessageType. + Type() MessageType + + // Encode writes the message bytes into the byte array from the argument. It + // returns the number of bytes encoded and whether there's any errors along + // the way. If there's any errors, then the byte slice and count should be + // considered invalid. + Encode([]byte) (int, error) + + // Decode reads the bytes in the byte slice from the argument. It returns the + // total number of bytes decoded, and whether there's any errors during the + // process. The byte slice MUST NOT be modified during the duration of this + // message being available since the byte slice is internally stored for + // references. + Decode([]byte) (int, error) + + Len() int +} + +const ( + // RESERVED is a reserved value and should be considered an invalid message type + RESERVED MessageType = iota + + // CONNECT: Client to Server. Client request to connect to Server. + CONNECT + + // CONNACK: Server to Client. Connect acknowledgement. + CONNACK + + // PUBLISH: Client to Server, or Server to Client. Publish message. + PUBLISH + + // PUBACK: Client to Server, or Server to Client. Publish acknowledgment for + // QoS 1 messages. + PUBACK + + // PUBACK: Client to Server, or Server to Client. Publish received for QoS 2 messages. + // Assured delivery part 1. + PUBREC + + // PUBREL: Client to Server, or Server to Client. Publish release for QoS 2 messages. + // Assured delivery part 1. + PUBREL + + // PUBCOMP: Client to Server, or Server to Client. Publish complete for QoS 2 messages. + // Assured delivery part 3. + PUBCOMP + + // SUBSCRIBE: Client to Server. Client subscribe request. + SUBSCRIBE + + // SUBACK: Server to Client. Subscribe acknowledgement. + SUBACK + + // UNSUBSCRIBE: Client to Server. Unsubscribe request. + UNSUBSCRIBE + + // UNSUBACK: Server to Client. Unsubscribe acknowlegment. + UNSUBACK + + // PINGREQ: Client to Server. PING request. + PINGREQ + + // PINGRESP: Server to Client. PING response. + PINGRESP + + // DISCONNECT: Client to Server. Client is disconnecting. + DISCONNECT + + // RESERVED2 is a reserved value and should be considered an invalid message type. + RESERVED2 +) + +// Name returns the name of the message type. It should correspond to one of the +// constant values defined for MessageType. It is statically defined and cannot +// be changed. +func (this MessageType) Name() string { + switch this { + case RESERVED: + return "RESERVED" + case CONNECT: + return "CONNECT" + case CONNACK: + return "CONNACK" + case PUBLISH: + return "PUBLISH" + case PUBACK: + return "PUBACK" + case PUBREC: + return "PUBREC" + case PUBREL: + return "PUBREL" + case PUBCOMP: + return "PUBCOMP" + case SUBSCRIBE: + return "SUBSCRIBE" + case SUBACK: + return "SUBACK" + case UNSUBSCRIBE: + return "UNSUBSCRIBE" + case UNSUBACK: + return "UNSUBACK" + case PINGREQ: + return "PINGREQ" + case PINGRESP: + return "PINGRESP" + case DISCONNECT: + return "DISCONNECT" + case RESERVED2: + return "RESERVED2" + } + + return "UNKNOWN" +} + +// Desc returns the description of the message type. It is statically defined (copied +// from MQTT spec) and cannot be changed. +func (this MessageType) Desc() string { + switch this { + case RESERVED: + return "Reserved" + case CONNECT: + return "Client request to connect to Server" + case CONNACK: + return "Connect acknowledgement" + case PUBLISH: + return "Publish message" + case PUBACK: + return "Publish acknowledgement" + case PUBREC: + return "Publish received (assured delivery part 1)" + case PUBREL: + return "Publish release (assured delivery part 2)" + case PUBCOMP: + return "Publish complete (assured delivery part 3)" + case SUBSCRIBE: + return "Client subscribe request" + case SUBACK: + return "Subscribe acknowledgement" + case UNSUBSCRIBE: + return "Unsubscribe request" + case UNSUBACK: + return "Unsubscribe acknowledgement" + case PINGREQ: + return "PING request" + case PINGRESP: + return "PING response" + case DISCONNECT: + return "Client is disconnecting" + case RESERVED2: + return "Reserved" + } + + return "UNKNOWN" +} + +// DefaultFlags returns the default flag values for the message type, as defined by +// the MQTT spec. +func (this MessageType) DefaultFlags() byte { + switch this { + case RESERVED: + return 0 + case CONNECT: + return 0 + case CONNACK: + return 0 + case PUBLISH: + return 0 + case PUBACK: + return 0 + case PUBREC: + return 0 + case PUBREL: + return 2 + case PUBCOMP: + return 0 + case SUBSCRIBE: + return 2 + case SUBACK: + return 0 + case UNSUBSCRIBE: + return 2 + case UNSUBACK: + return 0 + case PINGREQ: + return 0 + case PINGRESP: + return 0 + case DISCONNECT: + return 0 + case RESERVED2: + return 0 + } + + return 0 +} + +// New creates a new message based on the message type. It is a shortcut to call +// one of the New*Message functions. If an error is returned then the message type +// is invalid. +func (this MessageType) New() (Message, error) { + switch this { + case CONNECT: + return NewConnectMessage(), nil + case CONNACK: + return NewConnackMessage(), nil + case PUBLISH: + return NewPublishMessage(), nil + case PUBACK: + return NewPubackMessage(), nil + case PUBREC: + return NewPubrecMessage(), nil + case PUBREL: + return NewPubrelMessage(), nil + case PUBCOMP: + return NewPubcompMessage(), nil + case SUBSCRIBE: + return NewSubscribeMessage(), nil + case SUBACK: + return NewSubackMessage(), nil + case UNSUBSCRIBE: + return NewUnsubscribeMessage(), nil + case UNSUBACK: + return NewUnsubackMessage(), nil + case PINGREQ: + return NewPingreqMessage(), nil + case PINGRESP: + return NewPingrespMessage(), nil + case DISCONNECT: + return NewDisconnectMessage(), nil + } + + return nil, fmt.Errorf("msgtype/NewMessage: Invalid message type %d", this) +} + +// Valid returns a boolean indicating whether the message type is valid or not. +func (this MessageType) Valid() bool { + return this > RESERVED && this < RESERVED2 +} + +// ValidTopic checks the topic, which is a slice of bytes, to see if it's valid. Topic is +// considered valid if it's longer than 0 bytes, and doesn't contain any wildcard characters +// such as + and #. +func ValidTopic(topic []byte) bool { + return len(topic) > 0 && bytes.IndexByte(topic, '#') == -1 && bytes.IndexByte(topic, '+') == -1 +} + +// ValidQos checks the QoS value to see if it's valid. Valid QoS are QosAtMostOnce, +// QosAtLeastonce, and QosExactlyOnce. +func ValidQos(qos byte) bool { + return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce +} + +// ValidClientId checks the client ID, which is a slice of bytes, to see if it's valid. +// Client ID is valid if it meets the requirement from the MQTT spec: +// The Server MUST allow ClientIds which are between 1 and 23 UTF-8 encoded bytes in length, +// and that contain only the characters +// +// "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +func ValidClientId(cid []byte) bool { + return clientIdRegexp.Match(cid) +} + +// ValidVersion checks to see if the version is valid. Current supported versions include 0x3 and 0x4. +func ValidVersion(v byte) bool { + _, ok := SupportedVersions[v] + return ok +} + +// ValidConnackError checks to see if the error is a Connack Error or not +func ValidConnackError(err error) bool { + return err == ErrInvalidProtocolVersion || err == ErrIdentifierRejected || + err == ErrServerUnavailable || err == ErrBadUsernameOrPassword || err == ErrNotAuthorized +} diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..fc21b08 --- /dev/null +++ b/message_test.go @@ -0,0 +1,178 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +var ( + lpstrings []string = []string{ + "this is a test", + "hope it succeeds", + "but just in case", + "send me your millions", + "", + } + + lpstringBytes []byte = []byte{ + 0x0, 0xe, 't', 'h', 'i', 's', ' ', 'i', 's', ' ', 'a', ' ', 't', 'e', 's', 't', + 0x0, 0x10, 'h', 'o', 'p', 'e', ' ', 'i', 't', ' ', 's', 'u', 'c', 'c', 'e', 'e', 'd', 's', + 0x0, 0x10, 'b', 'u', 't', ' ', 'j', 'u', 's', 't', ' ', 'i', 'n', ' ', 'c', 'a', 's', 'e', + 0x0, 0x15, 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'y', 'o', 'u', 'r', ' ', 'm', 'i', 'l', 'l', 'i', 'o', 'n', 's', + 0x0, 0x0, + } + + msgBytes []byte = []byte{ + byte(CONNECT << 4), + 60, + 0, // Length MSB (0) + 4, // Length LSB (4) + 'M', 'Q', 'T', 'T', + 4, // Protocol level 4 + 206, // connect flags 11001110, will QoS = 01 + 0, // Keep Alive MSB (0) + 10, // Keep Alive LSB (10) + 0, // Client ID MSB (0) + 7, // Client ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Will Topic MSB (0) + 4, // Will Topic LSB (4) + 'w', 'i', 'l', 'l', + 0, // Will Message MSB (0) + 12, // Will Message LSB (12) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + 0, // Username ID MSB (0) + 7, // Username ID LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // Password ID MSB (0) + 10, // Password ID LSB (10) + 'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't', + } +) + +func TestReadLPBytes(t *testing.T) { + total := 0 + + for _, str := range lpstrings { + b, n, err := readLPBytes(lpstringBytes[total:]) + + assert.NoError(t, true, err) + assert.Equal(t, true, str, string(b)) + assert.Equal(t, true, len(str)+2, n) + + total += n + } +} + +func TestWriteLPBytes(t *testing.T) { + total := 0 + buf := make([]byte, 1000) + + for _, str := range lpstrings { + n, err := writeLPBytes(buf[total:], []byte(str)) + + assert.NoError(t, true, err) + assert.Equal(t, true, 2+len(str), n) + + total += n + } + + assert.Equal(t, true, lpstringBytes, buf[:total]) +} + +func TestMessageTypes(t *testing.T) { + if CONNECT != 1 || + CONNACK != 2 || + PUBLISH != 3 || + PUBACK != 4 || + PUBREC != 5 || + PUBREL != 6 || + PUBCOMP != 7 || + SUBSCRIBE != 8 || + SUBACK != 9 || + UNSUBSCRIBE != 10 || + UNSUBACK != 11 || + PINGREQ != 12 || + PINGRESP != 13 || + DISCONNECT != 14 { + + t.Errorf("Message types have invalid code") + } +} + +func TestQosCodes(t *testing.T) { + if QosAtMostOnce != 0 || QosAtLeastOnce != 1 || QosExactlyOnce != 2 { + t.Errorf("QOS codes invalid") + } +} + +func TestConnackReturnCodes(t *testing.T) { + assert.Equal(t, false, ErrInvalidProtocolVersion.Error(), ConnackCode(1).Error(), "Incorrect ConnackCode error value.") + + assert.Equal(t, false, ErrIdentifierRejected.Error(), ConnackCode(2).Error(), "Incorrect ConnackCode error value.") + + assert.Equal(t, false, ErrServerUnavailable.Error(), ConnackCode(3).Error(), "Incorrect ConnackCode error value.") + + assert.Equal(t, false, ErrBadUsernameOrPassword.Error(), ConnackCode(4).Error(), "Incorrect ConnackCode error value.") + + assert.Equal(t, false, ErrNotAuthorized.Error(), ConnackCode(5).Error(), "Incorrect ConnackCode error value.") +} + +func TestFixedHeaderFlags(t *testing.T) { + type detail struct { + name string + flags byte + } + + details := map[MessageType]detail{ + RESERVED: detail{"RESERVED", 0}, + CONNECT: detail{"CONNECT", 0}, + CONNACK: detail{"CONNACK", 0}, + PUBLISH: detail{"PUBLISH", 0}, + PUBACK: detail{"PUBACK", 0}, + PUBREC: detail{"PUBREC", 0}, + PUBREL: detail{"PUBREL", 2}, + PUBCOMP: detail{"PUBCOMP", 0}, + SUBSCRIBE: detail{"SUBSCRIBE", 2}, + SUBACK: detail{"SUBACK", 0}, + UNSUBSCRIBE: detail{"UNSUBSCRIBE", 2}, + UNSUBACK: detail{"UNSUBACK", 0}, + PINGREQ: detail{"PINGREQ", 0}, + PINGRESP: detail{"PINGRESP", 0}, + DISCONNECT: detail{"DISCONNECT", 0}, + RESERVED2: detail{"RESERVED2", 0}, + } + + for m, d := range details { + if m.Name() != d.name { + t.Errorf("Name mismatch. Expecting %s, got %s.", d.name, m.Name()) + } + + if m.DefaultFlags() != d.flags { + t.Errorf("Flag mismatch for %s. Expecting %d, got %d.", m.Name(), d.flags, m.DefaultFlags()) + } + } +} + +func TestSupportedVersions(t *testing.T) { + for k, v := range SupportedVersions { + if k == 0x03 && v != "MQIsdp" { + t.Errorf("Protocol version and name mismatch. Expect %s, got %s.", "MQIsdp", v) + } + } +} diff --git a/ping_test.go b/ping_test.go new file mode 100644 index 0000000..76835c8 --- /dev/null +++ b/ping_test.go @@ -0,0 +1,135 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestPingreqMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(PINGREQ << 4), + 0, + } + + msg := NewPingreqMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, PINGREQ, msg.Type(), "Error decoding message.") +} + +func TestPingreqMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PINGREQ << 4), + 0, + } + + msg := NewPingreqMessage() + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +func TestPingrespMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(PINGRESP << 4), + 0, + } + + msg := NewPingrespMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, PINGRESP, msg.Type(), "Error decoding message.") +} + +func TestPingrespMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PINGRESP << 4), + 0, + } + + msg := NewPingrespMessage() + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPingreqDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PINGREQ << 4), + 0, + } + + msg := NewPingreqMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPingrespDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PINGRESP << 4), + 0, + } + + msg := NewPingrespMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/pingreq.go b/pingreq.go new file mode 100644 index 0000000..40dba34 --- /dev/null +++ b/pingreq.go @@ -0,0 +1,42 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// The PINGREQ Packet is sent from a Client to the Server. It can be used to: +// 1. Indicate to the Server that the Client is alive in the absence of any other +// Control Packets being sent from the Client to the Server. +// 2. Request that the Server responds to confirm that it is alive. +// 3. Exercise the network to indicate that the Network Connection is active. +type PingreqMessage struct { + header +} + +var _ Message = (*PingreqMessage)(nil) + +// NewPingreqMessage creates a new PINGREQ message. +func NewPingreqMessage() *PingreqMessage { + msg := &PingreqMessage{} + msg.SetType(PINGREQ) + + return msg +} + +func (this *PingreqMessage) Decode(src []byte) (int, error) { + return this.header.decode(src) +} + +func (this *PingreqMessage) Encode(dst []byte) (int, error) { + return this.header.encode(dst) +} diff --git a/pingresp.go b/pingresp.go new file mode 100644 index 0000000..049fd01 --- /dev/null +++ b/pingresp.go @@ -0,0 +1,39 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ +// Packet. It indicates that the Server is alive. +type PingrespMessage struct { + header +} + +var _ Message = (*PingrespMessage)(nil) + +// NewPingrespMessage creates a new PINGRESP message. +func NewPingrespMessage() *PingrespMessage { + msg := &PingrespMessage{} + msg.SetType(PINGRESP) + + return msg +} + +func (this *PingrespMessage) Decode(src []byte) (int, error) { + return this.header.decode(src) +} + +func (this *PingrespMessage) Encode(dst []byte) (int, error) { + return this.header.encode(dst) +} diff --git a/puback.go b/puback.go new file mode 100644 index 0000000..246535c --- /dev/null +++ b/puback.go @@ -0,0 +1,91 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "encoding/binary" + "fmt" +) + +// A PUBACK Packet is the response to a PUBLISH Packet with QoS level 1. +type PubackMessage struct { + header +} + +var _ Message = (*PubackMessage)(nil) + +// NewPubackMessage creates a new PUBACK message. +func NewPubackMessage() *PubackMessage { + msg := &PubackMessage{} + msg.SetType(PUBACK) + + return msg +} + +func (this *PubackMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +func (this *PubackMessage) Decode(src []byte) (int, error) { + total := 0 + + n, err := this.header.decode(src[total:]) + total += n + if err != nil { + return total, err + } + + this.packetId = binary.BigEndian.Uint16(src[total:]) + total += 2 + + return total, nil +} + +func (this *PubackMessage) Encode(dst []byte) (int, error) { + hl := this.header.msglen() + ml := this.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return total, err + } + + binary.BigEndian.PutUint16(dst[total:], this.packetId) + total += 2 + + return total, nil +} + +func (this *PubackMessage) msglen() int { + // packet ID + return 2 +} diff --git a/puback_test.go b/puback_test.go new file mode 100644 index 0000000..843dc7f --- /dev/null +++ b/puback_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestPubackMessageFields(t *testing.T) { + msg := NewPubackMessage() + + msg.SetPacketId(100) + + assert.Equal(t, true, 100, msg.PacketId()) +} + +func TestPubackMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(PUBACK << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, PUBACK, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 7, msg.PacketId(), "Error decoding message.") +} + +// test insufficient bytes +func TestPubackMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(PUBACK << 4), + 2, + 7, // packet ID LSB (7) + } + + msg := NewPubackMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestPubackMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PUBACK << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubackMessage() + msg.SetPacketId(7) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPubackDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PUBACK << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/pubcomp.go b/pubcomp.go new file mode 100644 index 0000000..22f4c5f --- /dev/null +++ b/pubcomp.go @@ -0,0 +1,31 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// The PUBCOMP Packet is the response to a PUBREL Packet. It is the fourth and +// final packet of the QoS 2 protocol exchange. +type PubcompMessage struct { + PubackMessage +} + +var _ Message = (*PubcompMessage)(nil) + +// NewPubcompMessage creates a new PUBCOMP message. +func NewPubcompMessage() *PubcompMessage { + msg := &PubcompMessage{} + msg.SetType(PUBCOMP) + + return msg +} diff --git a/pubcomp_test.go b/pubcomp_test.go new file mode 100644 index 0000000..39cac37 --- /dev/null +++ b/pubcomp_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestPubcompMessageFields(t *testing.T) { + msg := NewPubcompMessage() + + msg.SetPacketId(100) + + assert.Equal(t, true, 100, msg.PacketId()) +} + +func TestPubcompMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(PUBCOMP << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubcompMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, PUBCOMP, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 7, msg.PacketId(), "Error decoding message.") +} + +// test insufficient bytes +func TestPubcompMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(PUBCOMP << 4), + 2, + 7, // packet ID LSB (7) + } + + msg := NewPubcompMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestPubcompMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PUBCOMP << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubcompMessage() + msg.SetPacketId(7) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPubcompDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PUBCOMP << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubcompMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/publish.go b/publish.go new file mode 100644 index 0000000..b1baf50 --- /dev/null +++ b/publish.go @@ -0,0 +1,230 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "encoding/binary" + "fmt" + "sync/atomic" +) + +// A PUBLISH Control Packet is sent from a Client to a Server or from Server to a Client +// to transport an Application Message. +type PublishMessage struct { + header + + topic []byte + payload []byte +} + +var _ Message = (*PublishMessage)(nil) + +// NewPublishMessage creates a new PUBLISH message. +func NewPublishMessage() *PublishMessage { + msg := &PublishMessage{} + msg.SetType(PUBLISH) + + return msg +} + +func (this PublishMessage) String() string { + return fmt.Sprintf("%v\nTopic: %s\nPacket ID: %d\nPayload: %s\n", + this.header, this.topic, this.packetId, string(this.payload)) +} + +// Dup returns the value specifying the duplicate delivery of a PUBLISH Control Packet. +// If the DUP flag is set to 0, it indicates that this is the first occasion that the +// Client or Server has attempted to send this MQTT PUBLISH Packet. If the DUP flag is +// set to 1, it indicates that this might be re-delivery of an earlier attempt to send +// the Packet. +func (this *PublishMessage) Dup() bool { + return ((this.flags >> 3) & 0x1) == 1 +} + +// SetDup sets the value specifying the duplicate delivery of a PUBLISH Control Packet. +func (this *PublishMessage) SetDup(v bool) { + if v { + this.flags |= 0x8 // 00001000 + } else { + this.flags &= 247 // 11110111 + } +} + +// Retain returns the value of the RETAIN flag. This flag is only used on the PUBLISH +// Packet. If the RETAIN flag is set to 1, in a PUBLISH Packet sent by a Client to a +// Server, the Server MUST store the Application Message and its QoS, so that it can be +// delivered to future subscribers whose subscriptions match its topic name. +func (this *PublishMessage) Retain() bool { + return (this.flags & 0x1) == 1 +} + +// SetRetain sets the value of the RETAIN flag. +func (this *PublishMessage) SetRetain(v bool) { + if v { + this.flags |= 0x1 // 00000001 + } else { + this.flags &= 254 // 11111110 + } +} + +// QoS returns the field that indicates the level of assurance for delivery of an +// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce. +func (this *PublishMessage) QoS() byte { + return (this.flags >> 1) & 0x3 +} + +// SetQoS sets the field that indicates the level of assurance for delivery of an +// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce. +// An error is returned if the value is not one of these. +func (this *PublishMessage) SetQoS(v byte) error { + if v != 0x0 && v != 0x1 && v != 0x2 { + return fmt.Errorf("publish/SetQoS: Invalid QoS %d.", v) + } + + this.flags = (this.flags & 249) | (v << 1) // 243 = 11111001 + return nil +} + +// Topic returns the the topic name that identifies the information channel to which +// payload data is published. +func (this *PublishMessage) Topic() []byte { + return this.topic +} + +// SetTopic sets the the topic name that identifies the information channel to which +// payload data is published. An error is returned if ValidTopic() is falbase. +func (this *PublishMessage) SetTopic(v []byte) error { + if !ValidTopic(v) { + return fmt.Errorf("publish/SetTopic: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(v)) + } + + this.topic = v + return nil +} + +// Payload returns the application message that's part of the PUBLISH message. +func (this *PublishMessage) Payload() []byte { + return this.payload +} + +// SetPayload sets the application message that's part of the PUBLISH message. +func (this *PublishMessage) SetPayload(v []byte) { + this.payload = v +} + +func (this *PublishMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +func (this *PublishMessage) Decode(src []byte) (int, error) { + total := 0 + + hn, err := this.header.decode(src[total:]) + total += hn + if err != nil { + return total, err + } + + n := 0 + + this.topic, n, err = readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + + if !ValidTopic(this.topic) { + return total, fmt.Errorf("publish/Decode: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(this.topic)) + } + + // The packet identifier field is only present in the PUBLISH packets where the + // QoS level is 1 or 2 + if this.QoS() != 0 { + this.packetId = binary.BigEndian.Uint16(src[total:]) + total += 2 + } + + l := int(this.remlen) - (total - hn) + this.payload = src[total : total+l] + total += len(this.payload) + + return total, nil +} + +func (this *PublishMessage) Encode(dst []byte) (int, error) { + if len(this.topic) == 0 { + return 0, fmt.Errorf("publish/Encode: Topic name is empty.") + } + + if len(this.payload) == 0 { + return 0, fmt.Errorf("publish/Encode: Payload is empty.") + } + + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + hl := this.header.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return total, err + } + + n, err = writeLPBytes(dst[total:], this.topic) + total += n + if err != nil { + return total, err + } + + // The packet identifier field is only present in the PUBLISH packets where the QoS level is 1 or 2 + if this.QoS() != 0 { + if this.packetId == 0 { + this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff) + } + + binary.BigEndian.PutUint16(dst[total:], this.packetId) + total += 2 + } + + copy(dst[total:], this.payload) + total += len(this.payload) + + return total, nil +} + +func (this *PublishMessage) msglen() int { + total := 2 + len(this.topic) + len(this.payload) + if this.QoS() != 0 { + total += 2 + } + + return total +} diff --git a/publish_test.go b/publish_test.go new file mode 100644 index 0000000..fd4bf96 --- /dev/null +++ b/publish_test.go @@ -0,0 +1,284 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" + "github.com/dataence/glog" +) + +func TestPublishMessageHeaderFields(t *testing.T) { + msg := NewPublishMessage() + msg.flags = 11 + + assert.True(t, true, msg.Dup(), "Incorrect DUP flag.") + assert.True(t, true, msg.Retain(), "Incorrect RETAIN flag.") + assert.Equal(t, true, 1, msg.QoS(), "Incorrect QoS.") + + msg.SetDup(false) + + assert.False(t, true, msg.Dup(), "Incorrect DUP flag.") + + msg.SetRetain(false) + + assert.False(t, true, msg.Retain(), "Incorrect RETAIN flag.") + + err := msg.SetQoS(2) + + assert.NoError(t, true, err, "Error setting QoS.") + assert.Equal(t, true, 2, msg.QoS(), "Incorrect QoS.") + + err = msg.SetQoS(3) + + assert.Error(t, true, err) + + err = msg.SetQoS(0) + + assert.NoError(t, true, err, "Error setting QoS.") + assert.Equal(t, true, 0, msg.QoS(), "Incorrect QoS.") + + msg.SetDup(true) + + assert.True(t, true, msg.Dup(), "Incorrect DUP flag.") + + msg.SetRetain(true) + + assert.True(t, true, msg.Retain(), "Incorrect RETAIN flag.") +} + +func TestPublishMessageFields(t *testing.T) { + msg := NewPublishMessage() + + msg.SetTopic([]byte("coolstuff")) + + assert.Equal(t, true, "coolstuff", string(msg.Topic()), "Error setting message topic.") + + err := msg.SetTopic([]byte("coolstuff/#")) + + assert.Error(t, true, err) + + msg.SetPacketId(100) + + assert.Equal(t, true, 100, msg.PacketId(), "Error setting acket ID.") + + msg.SetPayload([]byte("this is a payload to be sent")) + + assert.Equal(t, true, []byte("this is a payload to be sent"), msg.Payload(), "Error setting payload.") +} + +func TestPublishMessageDecode1(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH<<4) | 2, + 23, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + } + + msg := NewPublishMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, 7, msg.PacketId(), "Error decoding message.") + assert.Equal(t, true, "surgemq", string(msg.Topic()), "Error deocding topic name.") + assert.Equal(t, true, []byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}, msg.Payload(), "Error deocding payload.") +} + +// test insufficient bytes +func TestPublishMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH<<4) | 2, + 26, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + } + + msg := NewPublishMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +// test qos = 0 and no client id +func TestPublishMessageDecode3(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH << 4), + 21, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + } + + msg := NewPublishMessage() + _, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") +} + +func TestPublishMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH<<4) | 2, + 23, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + } + + msg := NewPublishMessage() + msg.SetTopic([]byte("surgemq")) + msg.SetQoS(1) + msg.SetPacketId(7) + msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}) + + dst := make([]byte, 100) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test empty topic name +func TestPublishMessageEncode2(t *testing.T) { + msg := NewPublishMessage() + msg.SetTopic([]byte("")) + msg.SetPacketId(7) + msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}) + + dst := make([]byte, 100) + _, err := msg.Encode(dst) + assert.Error(t, true, err) +} + +// test encoding qos = 0 and no packet id +func TestPublishMessageEncode3(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH << 4), + 21, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + } + + msg := NewPublishMessage() + msg.SetTopic([]byte("surgemq")) + msg.SetQoS(0) + msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}) + + dst := make([]byte, 100) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test large message +func TestPublishMessageEncode4(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH << 4), + 137, + 8, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + } + + payload := make([]byte, 1024) + msgBytes = append(msgBytes, payload...) + + glog.Debugf("msgBytes len = %d", len(msgBytes)) + + msg := NewPublishMessage() + msg.SetTopic([]byte("surgemq")) + msg.SetQoS(0) + msg.SetPayload(payload) + + assert.Equal(t, true, len(msgBytes), msg.Len()) + + dst := make([]byte, 1100) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test from github issue #2, @mrdg +func TestPublishDecodeEncodeEquiv2(t *testing.T) { + msgBytes := []byte{50, 18, 0, 9, 103, 114, 101, 101, 116, 105, 110, 103, 115, 0, 1, 72, 101, 108, 108, 111} + + msg := NewPublishMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPublishDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PUBLISH<<4) | 2, + 23, + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e', + } + + msg := NewPublishMessage() + + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/pubrec.go b/pubrec.go new file mode 100644 index 0000000..6a3daff --- /dev/null +++ b/pubrec.go @@ -0,0 +1,31 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +type PubrecMessage struct { + PubackMessage +} + +// A PUBREC Packet is the response to a PUBLISH Packet with QoS 2. It is the second +// packet of the QoS 2 protocol exchange. +var _ Message = (*PubrecMessage)(nil) + +// NewPubrecMessage creates a new PUBREC message. +func NewPubrecMessage() *PubrecMessage { + msg := &PubrecMessage{} + msg.SetType(PUBREC) + + return msg +} diff --git a/pubrec_test.go b/pubrec_test.go new file mode 100644 index 0000000..30ac681 --- /dev/null +++ b/pubrec_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestPubrecMessageFields(t *testing.T) { + msg := NewPubrecMessage() + + msg.SetPacketId(100) + + assert.Equal(t, true, 100, msg.PacketId()) +} + +func TestPubrecMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(PUBREC << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubrecMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, PUBREC, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 7, msg.PacketId(), "Error decoding message.") +} + +// test insufficient bytes +func TestPubrecMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(PUBREC << 4), + 2, + 7, // packet ID LSB (7) + } + + msg := NewPubrecMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestPubrecMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PUBREC << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubrecMessage() + msg.SetPacketId(7) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPubrecDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PUBREC << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubrecMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/pubrel.go b/pubrel.go new file mode 100644 index 0000000..ab0f04b --- /dev/null +++ b/pubrel.go @@ -0,0 +1,31 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// A PUBREL Packet is the response to a PUBREC Packet. It is the third packet of the +// QoS 2 protocol exchange. +type PubrelMessage struct { + PubackMessage +} + +var _ Message = (*PubrelMessage)(nil) + +// NewPubrelMessage creates a new PUBREL message. +func NewPubrelMessage() *PubrelMessage { + msg := &PubrelMessage{} + msg.SetType(PUBREL) + + return msg +} diff --git a/pubrel_test.go b/pubrel_test.go new file mode 100644 index 0000000..ddce3f6 --- /dev/null +++ b/pubrel_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestPubrelMessageFields(t *testing.T) { + msg := NewPubrelMessage() + + msg.SetPacketId(100) + + assert.Equal(t, true, 100, msg.PacketId()) +} + +func TestPubrelMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(PUBREL<<4) | 2, + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubrelMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, PUBREL, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 7, msg.PacketId(), "Error decoding message.") +} + +// test insufficient bytes +func TestPubrelMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(PUBREL<<4) | 2, + 2, + 7, // packet ID LSB (7) + } + + msg := NewPubrelMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestPubrelMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(PUBREL<<4) | 2, + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubrelMessage() + msg.SetPacketId(7) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestPubrelDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(PUBREL<<4) | 2, + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewPubrelMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/suback.go b/suback.go new file mode 100644 index 0000000..6449e93 --- /dev/null +++ b/suback.go @@ -0,0 +1,144 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "encoding/binary" + "fmt" +) + +// A SUBACK Packet is sent by the Server to the Client to confirm receipt and processing +// of a SUBSCRIBE Packet. +// +// A SUBACK Packet contains a list of return codes, that specify the maximum QoS level +// that was granted in each Subscription that was requested by the SUBSCRIBE. +type SubackMessage struct { + header + + returnCodes []byte +} + +var _ Message = (*SubackMessage)(nil) + +// NewSubackMessage creates a new SUBACK message. +func NewSubackMessage() *SubackMessage { + msg := &SubackMessage{} + msg.SetType(SUBACK) + + return msg +} + +// String returns a string representation of the message. +func (this SubackMessage) String() string { + return fmt.Sprintf("%s\nPacket ID: %d\nReturn Codes: %v\n", this.header, this.packetId, this.returnCodes) +} + +// ReturnCodes returns the list of QoS returns from the subscriptions sent in the SUBSCRIBE message. +func (this *SubackMessage) ReturnCodes() []byte { + return this.returnCodes +} + +// AddReturnCodes sets the list of QoS returns from the subscriptions sent in the SUBSCRIBE message. +// An error is returned if any of the QoS values are not valid. +func (this *SubackMessage) AddReturnCodes(ret []byte) error { + for _, c := range ret { + if c != QosAtMostOnce && c != QosAtLeastOnce && c != QosExactlyOnce && c != QosFailure { + return fmt.Errorf("suback/AddReturnCode: Invalid return code %d. Must be 0, 1, 2, 0x80.", c) + } + + this.returnCodes = append(this.returnCodes, c) + } + + return nil +} + +// AddReturnCode adds a single QoS return value. +func (this *SubackMessage) AddReturnCode(ret byte) error { + return this.AddReturnCodes([]byte{ret}) +} + +func (this *SubackMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +func (this *SubackMessage) Decode(src []byte) (int, error) { + total := 0 + + hn, err := this.header.decode(src[total:]) + total += hn + if err != nil { + return total, err + } + + this.packetId = binary.BigEndian.Uint16(src[total:]) + total += 2 + + l := int(this.remlen) - (total - hn) + this.returnCodes = src[total : total+l] + total += len(this.returnCodes) + + for i, code := range this.returnCodes { + if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 { + return total, fmt.Errorf("suback/Decode: Invalid return code %d for topic %d", code, i) + } + } + + return total, nil +} + +func (this *SubackMessage) Encode(dst []byte) (int, error) { + for i, code := range this.returnCodes { + if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 { + return 0, fmt.Errorf("suback/Encode: Invalid return code %d for topic %d", code, i) + } + } + + hl := this.header.msglen() + ml := this.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return total, err + } + + binary.BigEndian.PutUint16(dst[total:], this.packetId) + total += 2 + + copy(dst[total:], this.returnCodes) + total += len(this.returnCodes) + + return total, nil +} + +func (this *SubackMessage) msglen() int { + return 2 + len(this.returnCodes) +} diff --git a/suback_test.go b/suback_test.go new file mode 100644 index 0000000..3a7811c --- /dev/null +++ b/suback_test.go @@ -0,0 +1,134 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestSubackMessageFields(t *testing.T) { + msg := NewSubackMessage() + + msg.SetPacketId(100) + assert.Equal(t, true, 100, msg.PacketId(), "Error setting packet ID.") + + msg.AddReturnCode(1) + assert.Equal(t, true, 1, len(msg.ReturnCodes()), "Error adding return code.") + + err := msg.AddReturnCode(0x90) + assert.Error(t, true, err) +} + +func TestSubackMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(SUBACK << 4), + 6, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // return code 1 + 1, // return code 2 + 2, // return code 3 + 0x80, // return code 4 + } + + msg := NewSubackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, SUBACK, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 4, len(msg.ReturnCodes()), "Error adding return code.") +} + +// test with wrong return code +func TestSubackMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(SUBACK << 4), + 6, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // return code 1 + 1, // return code 2 + 2, // return code 3 + 0x81, // return code 4 + } + + msg := NewSubackMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestSubackMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(SUBACK << 4), + 6, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // return code 1 + 1, // return code 2 + 2, // return code 3 + 0x80, // return code 4 + } + + msg := NewSubackMessage() + msg.SetPacketId(7) + msg.AddReturnCode(0) + msg.AddReturnCode(1) + msg.AddReturnCode(2) + msg.AddReturnCode(0x80) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestSubackDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(SUBACK << 4), + 6, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // return code 1 + 1, // return code 2 + 2, // return code 3 + 0x80, // return code 4 + } + + msg := NewSubackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/subscribe.go b/subscribe.go new file mode 100644 index 0000000..a5099a6 --- /dev/null +++ b/subscribe.go @@ -0,0 +1,224 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "bytes" + "encoding/binary" + "fmt" + "sync/atomic" +) + +// The SUBSCRIBE Packet is sent from the Client to the Server to create one or more +// Subscriptions. Each Subscription registers a Client’s interest in one or more +// Topics. The Server sends PUBLISH Packets to the Client in order to forward +// Application Messages that were published to Topics that match these Subscriptions. +// The SUBSCRIBE Packet also specifies (for each Subscription) the maximum QoS with +// which the Server can send Application Messages to the Client. +type SubscribeMessage struct { + header + + topics [][]byte + qos []byte +} + +var _ Message = (*SubscribeMessage)(nil) + +// NewSubscribeMessage creates a new SUBSCRIBE message. +func NewSubscribeMessage() *SubscribeMessage { + msg := &SubscribeMessage{} + msg.SetType(SUBSCRIBE) + + return msg +} + +// Topics returns a list of topics sent by the Client. +func (this *SubscribeMessage) Topics() [][]byte { + return this.topics +} + +// AddTopic adds a single topic to the message, along with the corresponding QoS. +// An error is returned if QoS is invalid. +func (this *SubscribeMessage) AddTopic(topic []byte, qos byte) error { + if !ValidQos(qos) { + return fmt.Errorf("Invalid QoS %d", qos) + } + + var i int + var t []byte + var found bool + + for i, t = range this.topics { + if bytes.Equal(t, topic) { + found = true + break + } + } + + if found { + this.qos[i] = qos + return nil + } + + this.topics = append(this.topics, topic) + this.qos = append(this.qos, qos) + + return nil +} + +// RemoveTopic removes a single topic from the list of existing ones in the message. +// If topic does not exist it just does nothing. +func (this *SubscribeMessage) RemoveTopic(topic []byte) { + var i int + var t []byte + var found bool + + for i, t = range this.topics { + if bytes.Equal(t, topic) { + found = true + break + } + } + + if found { + this.topics = append(this.topics[:i], this.topics[i+1:]...) + this.qos = append(this.qos[:i], this.qos[i+1:]...) + } +} + +// TopicExists checks to see if a topic exists in the list. +func (this *SubscribeMessage) TopicExists(topic []byte) bool { + for _, t := range this.topics { + if bytes.Equal(t, topic) { + return true + } + } + + return false +} + +// TopicQos returns the QoS level of a topic. If topic does not exist, QosFailure +// is returned. +func (this *SubscribeMessage) TopicQos(topic []byte) byte { + for i, t := range this.topics { + if bytes.Equal(t, topic) { + return this.qos[i] + } + } + + return QosFailure +} + +// Qos returns the list of QoS current in the message. +func (this *SubscribeMessage) Qos() []byte { + return this.qos +} + +func (this *SubscribeMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +func (this *SubscribeMessage) Decode(src []byte) (int, error) { + total := 0 + + hn, err := this.header.decode(src[total:]) + total += hn + if err != nil { + return total, err + } + + this.packetId = binary.BigEndian.Uint16(src[total:]) + total += 2 + + remlen := int(this.remlen) - (total - hn) + for remlen > 0 { + t, n, err := readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + + this.topics = append(this.topics, t) + + this.qos = append(this.qos, src[total]) + total++ + + remlen = remlen - n - 1 + } + + if len(this.topics) == 0 { + return 0, fmt.Errorf("subscribe/Decode: Empty topic list") + } + + return total, nil +} + +func (this *SubscribeMessage) Encode(dst []byte) (int, error) { + hl := this.header.msglen() + ml := this.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return total, err + } + + if this.packetId == 0 { + this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff) + } + + binary.BigEndian.PutUint16(dst[total:], this.packetId) + total += 2 + + for i, t := range this.topics { + n, err := writeLPBytes(dst[total:], t) + total += n + if err != nil { + return total, err + } + + dst[total] = this.qos[i] + total++ + } + + return total, nil +} + +func (this *SubscribeMessage) msglen() int { + // packet ID + total := 2 + + for _, t := range this.topics { + total += 2 + len(t) + 1 + } + + return total +} diff --git a/subscribe_test.go b/subscribe_test.go new file mode 100644 index 0000000..3accd98 --- /dev/null +++ b/subscribe_test.go @@ -0,0 +1,161 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestSubscribeMessageFields(t *testing.T) { + msg := NewSubscribeMessage() + + msg.SetPacketId(100) + assert.Equal(t, true, 100, msg.PacketId(), "Error setting packet ID.") + + msg.AddTopic([]byte("/a/b/#/c"), 1) + assert.Equal(t, true, 1, len(msg.Topics()), "Error adding topic.") + + assert.False(t, true, msg.TopicExists([]byte("a/b")), "Topic should not exist.") + + msg.RemoveTopic([]byte("/a/b/#/c")) + assert.False(t, true, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.") +} + +func TestSubscribeMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(SUBSCRIBE<<4) | 2, + 36, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // QoS + 0, // topic name MSB (0) + 8, // topic name LSB (8) + '/', 'a', '/', 'b', '/', '#', '/', 'c', + 1, // QoS + 0, // topic name MSB (0) + 10, // topic name LSB (10) + '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', + 2, // QoS + } + + msg := NewSubscribeMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, SUBSCRIBE, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 3, len(msg.Topics()), "Error decoding topics.") + assert.True(t, true, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.") + assert.Equal(t, true, 0, msg.TopicQos([]byte("surgemq")), "Incorrect topic qos.") + assert.True(t, true, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.") + assert.Equal(t, true, 1, msg.TopicQos([]byte("/a/b/#/c")), "Incorrect topic qos.") + assert.True(t, true, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.") + assert.Equal(t, true, 2, msg.TopicQos([]byte("/a/b/#/cdd")), "Incorrect topic qos.") +} + +// test empty topic list +func TestSubscribeMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(SUBSCRIBE<<4) | 2, + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewSubscribeMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestSubscribeMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(SUBSCRIBE<<4) | 2, + 36, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // QoS + 0, // topic name MSB (0) + 8, // topic name LSB (8) + '/', 'a', '/', 'b', '/', '#', '/', 'c', + 1, // QoS + 0, // topic name MSB (0) + 10, // topic name LSB (10) + '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', + 2, // QoS + } + + msg := NewSubscribeMessage() + msg.SetPacketId(7) + msg.AddTopic([]byte("surgemq"), 0) + msg.AddTopic([]byte("/a/b/#/c"), 1) + msg.AddTopic([]byte("/a/b/#/cdd"), 2) + + dst := make([]byte, 100) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestSubscribeDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(SUBSCRIBE<<4) | 2, + 36, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // QoS + 0, // topic name MSB (0) + 8, // topic name LSB (8) + '/', 'a', '/', 'b', '/', '#', '/', 'c', + 1, // QoS + 0, // topic name MSB (0) + 10, // topic name LSB (10) + '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', + 2, // QoS + } + + msg := NewSubscribeMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/unsuback.go b/unsuback.go new file mode 100644 index 0000000..4fb5be3 --- /dev/null +++ b/unsuback.go @@ -0,0 +1,31 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +// The UNSUBACK Packet is sent by the Server to the Client to confirm receipt of an +// UNSUBSCRIBE Packet. +type UnsubackMessage struct { + PubackMessage +} + +var _ Message = (*UnsubackMessage)(nil) + +// NewUnsubackMessage creates a new UNSUBACK message. +func NewUnsubackMessage() *UnsubackMessage { + msg := &UnsubackMessage{} + msg.SetType(UNSUBACK) + + return msg +} diff --git a/unsuback_test.go b/unsuback_test.go new file mode 100644 index 0000000..00ef11f --- /dev/null +++ b/unsuback_test.go @@ -0,0 +1,108 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestUnsubackMessageFields(t *testing.T) { + msg := NewUnsubackMessage() + + msg.SetPacketId(100) + + assert.Equal(t, true, 100, msg.PacketId()) +} + +func TestUnsubackMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBACK << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewUnsubackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, UNSUBACK, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 7, msg.PacketId(), "Error decoding message.") +} + +// test insufficient bytes +func TestUnsubackMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBACK << 4), + 2, + 7, // packet ID LSB (7) + } + + msg := NewUnsubackMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestUnsubackMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBACK << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewUnsubackMessage() + msg.SetPacketId(7) + + dst := make([]byte, 10) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestUnsubackDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBACK << 4), + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewUnsubackMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/unsubscribe.go b/unsubscribe.go new file mode 100644 index 0000000..312d5f2 --- /dev/null +++ b/unsubscribe.go @@ -0,0 +1,181 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "bytes" + "encoding/binary" + "fmt" + "sync/atomic" +) + +// An UNSUBSCRIBE Packet is sent by the Client to the Server, to unsubscribe from topics. +type UnsubscribeMessage struct { + header + + topics [][]byte +} + +var _ Message = (*UnsubscribeMessage)(nil) + +// NewUnsubscribeMessage creates a new UNSUBSCRIBE message. +func NewUnsubscribeMessage() *UnsubscribeMessage { + msg := &UnsubscribeMessage{} + msg.SetType(UNSUBSCRIBE) + + return msg +} + +// Topics returns a list of topics sent by the Client. +func (this *UnsubscribeMessage) Topics() [][]byte { + return this.topics +} + +// AddTopic adds a single topic to the message. +func (this *UnsubscribeMessage) AddTopic(topic []byte) { + if this.TopicExists(topic) { + return + } + + this.topics = append(this.topics, topic) +} + +// RemoveTopic removes a single topic from the list of existing ones in the message. +// If topic does not exist it just does nothing. +func (this *UnsubscribeMessage) RemoveTopic(topic []byte) { + var i int + var t []byte + var found bool + + for i, t = range this.topics { + if bytes.Equal(t, topic) { + found = true + break + } + } + + if found { + this.topics = append(this.topics[:i], this.topics[i+1:]...) + } +} + +// TopicExists checks to see if a topic exists in the list. +func (this *UnsubscribeMessage) TopicExists(topic []byte) bool { + for _, t := range this.topics { + if bytes.Equal(t, topic) { + return true + } + } + + return false +} + +func (this *UnsubscribeMessage) Len() int { + ml := this.msglen() + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0 + } + + return this.header.msglen() + ml +} + +// Decode reads from the io.Reader parameter until a full message is decoded, or +// when io.Reader returns EOF or error. The first return value is the number of +// bytes read from io.Reader. The second is error if Decode encounters any problems. +func (this *UnsubscribeMessage) Decode(src []byte) (int, error) { + total := 0 + + hn, err := this.header.decode(src[total:]) + total += hn + if err != nil { + return total, err + } + + this.packetId = binary.BigEndian.Uint16(src[total:]) + total += 2 + + remlen := int(this.remlen) - (total - hn) + for remlen > 0 { + t, n, err := readLPBytes(src[total:]) + total += n + if err != nil { + return total, err + } + + this.topics = append(this.topics, t) + remlen = remlen - n - 1 + } + + if len(this.topics) == 0 { + return 0, fmt.Errorf("unsubscribe/Decode: Empty topic list") + } + + return total, nil +} + +// Encode returns an io.Reader in which the encoded bytes can be read. The second +// return value is the number of bytes encoded, so the caller knows how many bytes +// there will be. If Encode returns an error, then the first two return values +// should be considered invalid. +// Any changes to the message after Encode() is called will invalidate the io.Reader. +func (this *UnsubscribeMessage) Encode(dst []byte) (int, error) { + hl := this.header.msglen() + ml := this.msglen() + + if len(dst) < hl+ml { + return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst)) + } + + if err := this.SetRemainingLength(int32(ml)); err != nil { + return 0, err + } + + total := 0 + + n, err := this.header.encode(dst[total:]) + total += n + if err != nil { + return total, err + } + + if this.packetId == 0 { + this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff) + } + + binary.BigEndian.PutUint16(dst[total:], this.packetId) + total += 2 + + for _, t := range this.topics { + n, err := writeLPBytes(dst[total:], t) + total += n + if err != nil { + return total, err + } + } + + return total, nil +} + +func (this *UnsubscribeMessage) msglen() int { + // packet ID + total := 2 + + for _, t := range this.topics { + total += 2 + len(t) + } + + return total +} diff --git a/unsubscribe_test.go b/unsubscribe_test.go new file mode 100644 index 0000000..a620e4d --- /dev/null +++ b/unsubscribe_test.go @@ -0,0 +1,155 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "testing" + + "github.com/dataence/assert" +) + +func TestUnsubscribeMessageFields(t *testing.T) { + msg := NewUnsubscribeMessage() + + msg.SetPacketId(100) + assert.Equal(t, true, 100, msg.PacketId(), "Error setting packet ID.") + + msg.AddTopic([]byte("/a/b/#/c")) + assert.Equal(t, true, 1, len(msg.Topics()), "Error adding topic.") + + msg.AddTopic([]byte("/a/b/#/c")) + assert.Equal(t, true, 1, len(msg.Topics()), "Error adding duplicate topic.") + + msg.RemoveTopic([]byte("/a/b/#/c")) + assert.False(t, true, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.") + + assert.False(t, true, msg.TopicExists([]byte("a/b")), "Topic should not exist.") + + msg.RemoveTopic([]byte("/a/b/#/c")) + assert.False(t, true, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.") +} + +func TestUnsubscribeMessageDecode(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBSCRIBE<<4) | 2, + 33, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // topic name MSB (0) + 8, // topic name LSB (8) + '/', 'a', '/', 'b', '/', '#', '/', 'c', + 0, // topic name MSB (0) + 10, // topic name LSB (10) + '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', + } + + msg := NewUnsubscribeMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, UNSUBSCRIBE, msg.Type(), "Error decoding message.") + assert.Equal(t, true, 3, len(msg.Topics()), "Error decoding topics.") + assert.True(t, true, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.") + assert.True(t, true, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.") + assert.True(t, true, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.") +} + +// test empty topic list +func TestUnsubscribeMessageDecode2(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBSCRIBE<<4) | 2, + 2, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + } + + msg := NewUnsubscribeMessage() + _, err := msg.Decode(msgBytes) + + assert.Error(t, true, err) +} + +func TestUnsubscribeMessageEncode(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBSCRIBE<<4) | 2, + 33, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // topic name MSB (0) + 8, // topic name LSB (8) + '/', 'a', '/', 'b', '/', '#', '/', 'c', + 0, // topic name MSB (0) + 10, // topic name LSB (10) + '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', + } + + msg := NewUnsubscribeMessage() + msg.SetPacketId(7) + msg.AddTopic([]byte("surgemq")) + msg.AddTopic([]byte("/a/b/#/c")) + msg.AddTopic([]byte("/a/b/#/cdd")) + + dst := make([]byte, 100) + n, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n], "Error decoding message.") +} + +// test to ensure encoding and decoding are the same +// decode, encode, and decode again +func TestUnsubscribeDecodeEncodeEquiv(t *testing.T) { + msgBytes := []byte{ + byte(UNSUBSCRIBE<<4) | 2, + 33, + 0, // packet ID MSB (0) + 7, // packet ID LSB (7) + 0, // topic name MSB (0) + 7, // topic name LSB (7) + 's', 'u', 'r', 'g', 'e', 'm', 'q', + 0, // topic name MSB (0) + 8, // topic name LSB (8) + '/', 'a', '/', 'b', '/', '#', '/', 'c', + 0, // topic name MSB (0) + 10, // topic name LSB (10) + '/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd', + } + + msg := NewUnsubscribeMessage() + n, err := msg.Decode(msgBytes) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n, "Error decoding message.") + + dst := make([]byte, 100) + n2, err := msg.Encode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n2, "Error decoding message.") + assert.Equal(t, true, msgBytes, dst[:n2], "Error decoding message.") + + n3, err := msg.Decode(dst) + + assert.NoError(t, true, err, "Error decoding message.") + assert.Equal(t, true, len(msgBytes), n3, "Error decoding message.") +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..8098acd --- /dev/null +++ b/utils.go @@ -0,0 +1,61 @@ +// Copyright (c) 2014 Dataence, LLC. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message + +import ( + "encoding/binary" + "fmt" +) + +// Read length prefixed bytes +func readLPBytes(buf []byte) ([]byte, int, error) { + if len(buf) < 2 { + return nil, 0, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2, len(buf)) + } + + n, total := 0, 0 + + n = int(binary.BigEndian.Uint16(buf)) + total += 2 + + if len(buf) < n { + return nil, total, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", n, len(buf)) + } + + total += n + + return buf[2:total], total, nil +} + +// Write length prefixed bytes +func writeLPBytes(buf []byte, b []byte) (int, error) { + total, n := 0, len(b) + + if n > int(maxLPString) { + return 0, fmt.Errorf("utils/writeLPBytes: Length (%d) greater than %d bytes.", n, maxLPString) + } + + if len(buf) < 2+n { + return 0, fmt.Errorf("utils/writeLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2+n, len(buf)) + } + + binary.BigEndian.PutUint16(buf, uint16(n)) + total += 2 + + copy(buf[total:], b) + total += n + + return total, nil +}