Skip to content

Commit

Permalink
WIP use the subscription identifier in publish
Browse files Browse the repository at this point in the history
  • Loading branch information
halfgaar committed Nov 9, 2024
1 parent 4a3b6aa commit c4c7c77
Show file tree
Hide file tree
Showing 21 changed files with 70 additions and 33 deletions.
2 changes: 1 addition & 1 deletion FlashMQTests/plugintests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ void MainTests::testFailedAsyncClientCrashOnSession()
FlashMQTestClient &second_client = clients.back();

Publish pub("sdf", "wer", 2);
MqttPacket pubPack(second_client.getClient()->getProtocolVersion(), pub);
MqttPacket pubPack(second_client.getClient()->getProtocolVersion(), pub, 0);
if (pub.qos > 0)
pubPack.setPacketId(3);
second_client.getClient()->writeMqttPacketAndBlameThisClient(pubPack);
Expand Down
8 changes: 4 additions & 4 deletions FlashMQTests/tst_maintests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ void MainTests::testPacketInt16Parse()
for (const uint16_t id : tests)
{
Publish pub("hallo", "content", 1);
MqttPacket packet(ProtocolVersion::Mqtt311, pub);
MqttPacket packet(ProtocolVersion::Mqtt311, pub, 0);
packet.setPacketId(id);
packet.pos -= 2;
uint16_t idParsed = packet.readTwoBytesToUInt16();
Expand Down Expand Up @@ -978,9 +978,9 @@ void MainTests::testSavingSessions()

std::shared_ptr<Session> c1ses = c1->getSession();
c1.reset();
MqttPacket publishPacket(ProtocolVersion::Mqtt5, publish);
MqttPacket publishPacket(ProtocolVersion::Mqtt5, publish, 1); // TODO: subscription identifier, what to do here?
PublishCopyFactory fac(&publishPacket);
c1ses->writePacket(fac, 1, false);
c1ses->writePacket(fac, 1, false, publishPacket.getPublishData().subscriptionIdentifier); // TODO: subscription identifier? What to do / test?

FlashMQTempDir tmpdir;
auto dbpath = tmpdir.getPath() / "flashmqtests_sessions.db";
Expand Down Expand Up @@ -1102,7 +1102,7 @@ void MainTests::testParsePacketHelper(const std::string &topic, uint8_t from_qos
const std::string payloadOne = getSecureRandomString(len);
Publish pubOne(topic, payloadOne, from_qos);
pubOne.retain = retain;
MqttPacket stagingPacketOne(ProtocolVersion::Mqtt311, pubOne);
MqttPacket stagingPacketOne(ProtocolVersion::Mqtt311, pubOne, 0);
if (from_qos > 0)
stagingPacketOne.setPacketId(pack_id);
CirBuf stagingBufOne(1024);
Expand Down
5 changes: 3 additions & 2 deletions client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ PacketDropReason Client::writeMqttPacket(const MqttPacket &packet)
return PacketDropReason::Success;
}

PacketDropReason Client::writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, uint8_t max_qos, uint16_t packet_id, bool retain)
PacketDropReason Client::writeMqttPacketAndBlameThisClient(
PublishCopyFactory &copyFactory, uint8_t max_qos, uint16_t packet_id, bool retain, uint32_t subscriptionIdentifier)
{
uint16_t topic_alias = 0;
uint16_t topic_alias_next = 0;
Expand Down Expand Up @@ -363,7 +364,7 @@ PacketDropReason Client::writeMqttPacketAndBlameThisClient(PublishCopyFactory &c
}
}

MqttPacket *p = copyFactory.getOptimumPacket(max_qos, this->protocolVersion, topic_alias, skip_topic);
MqttPacket *p = copyFactory.getOptimumPacket(max_qos, this->protocolVersion, topic_alias, skip_topic, subscriptionIdentifier);

assert(static_cast<bool>(p->getQos()) == static_cast<bool>(max_qos));
assert(PublishCopyFactory::getPublishLayoutCompareKey(protocolVersion) == PublishCopyFactory::getPublishLayoutCompareKey(p->getProtocolVersion()));
Expand Down
2 changes: 1 addition & 1 deletion client.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class Client
void writePingResp();
void writeLoginPacket();
PacketDropReason writeMqttPacket(const MqttPacket &packet);
PacketDropReason writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, uint8_t max_qos, uint16_t packet_id, bool retain);
PacketDropReason writeMqttPacketAndBlameThisClient(PublishCopyFactory &copyFactory, uint8_t max_qos, uint16_t packet_id, bool retain, uint32_t subscriptionIdentifier);
PacketDropReason writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
void writeBufIntoFd();
DisconnectStage getDisconnectStage() const { return disconnectStage; }
Expand Down
2 changes: 1 addition & 1 deletion flashmqtestclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ void FlashMQTestClient::publish(Publish &pub)

const uint16_t packet_id = 77;

MqttPacket pubPack(client->getProtocolVersion(), pub);
MqttPacket pubPack(client->getProtocolVersion(), pub, 0);
if (pub.qos > 0)
pubPack.setPacketId(packet_id);
client->writeMqttPacketAndBlameThisClient(pubPack);
Expand Down
18 changes: 18 additions & 0 deletions mqtt5properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ void Mqtt5PropertyBuilder::writeWildcardSubscriptionAvailable(uint8_t val)
writeUint8(Mqtt5Properties::WildcardSubscriptionAvailable, val);
}

void Mqtt5PropertyBuilder::writeSubscriptionIdentifier(uint32_t val)
{
writeVariableByteInt(Mqtt5Properties::SubscriptionIdentifier, val);
}

void Mqtt5PropertyBuilder::writeSubscriptionIdentifiersAvailable(uint8_t val)
{
writeUint8(Mqtt5Properties::SubscriptionIdentifierAvailable, val);
Expand Down Expand Up @@ -243,3 +248,16 @@ void Mqtt5PropertyBuilder::write2Str(Mqtt5Properties prop, const std::string &on
}
}

void Mqtt5PropertyBuilder::writeVariableByteInt(Mqtt5Properties prop, const uint32_t val)
{
const VariableByteInt x(val);

size_t pos = bytes.size();
const size_t newSize = pos + x.getLen() + 1;
bytes.resize(newSize);

bytes[pos++] = static_cast<uint8_t>(prop);

std::memcpy(&bytes[pos], x.data(), x.getLen());
}

2 changes: 2 additions & 0 deletions mqtt5properties.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Mqtt5PropertyBuilder
void writeUint8(Mqtt5Properties prop, const uint8_t x);
void writeStr(Mqtt5Properties prop, const std::string &str);
void write2Str(Mqtt5Properties prop, const std::string &one, const std::string &two);
void writeVariableByteInt(Mqtt5Properties prop, const unsigned int val);
public:
Mqtt5PropertyBuilder();

Expand All @@ -43,6 +44,7 @@ class Mqtt5PropertyBuilder
void writeAssignedClientId(const std::string &clientid);
void writeMaxTopicAliases(uint16_t val);
void writeWildcardSubscriptionAvailable(uint8_t val);
void writeSubscriptionIdentifier(uint32_t val);
void writeSubscriptionIdentifiersAvailable(uint8_t val);
void writeSharedSubscriptionAvailable(uint8_t val);
void writeContentType(const std::string &format);
Expand Down
8 changes: 5 additions & 3 deletions mqttpacket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ MqttPacket::MqttPacket(const UnsubAck &unsubAck) :
calculateRemainingLength();
}

MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish) :
MqttPacket(protocolVersion, _publish, _publish.qos, _publish.topicAlias, _publish.skipTopic)
MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint32_t subscriptionIdentifier) :
MqttPacket(protocolVersion, _publish, _publish.qos, _publish.topicAlias, _publish.skipTopic, subscriptionIdentifier)
{

}
Expand All @@ -124,7 +124,8 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu
* The extra parameters are for overriding certain properties of the publish, because the receiving client wants it differently. Use the other overload
* if you just want the publish object's data.
*/
MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint8_t _qos, const uint16_t _topic_alias, const bool _skip_topic)
MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint8_t _qos, const uint16_t _topic_alias,
const bool _skip_topic, const uint32_t subscriptionIdentifier)
{
if (_publish.topic.length() > 0xFFFF)
{
Expand Down Expand Up @@ -159,6 +160,7 @@ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_pu
this->publishData.contentType = _publish.contentType;
this->publishData.payloadUtf8 = _publish.payloadUtf8;
this->publishData.userProperties = _publish.userProperties;
this->publishData.subscriptionIdentifier = subscriptionIdentifier;

property_builder = this->publishData.getPropertyBuilder();
}
Expand Down
5 changes: 3 additions & 2 deletions mqttpacket.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ class MqttPacket
MqttPacket(const ConnAck &connAck);
MqttPacket(const SubAck &subAck);
MqttPacket(const UnsubAck &unsubAck);
MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish);
MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint8_t _qos, const uint16_t _topic_alias, const bool _skip_topic);
MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint32_t subscriptionIdentifier);
MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint8_t _qos, const uint16_t _topic_alias,
const bool _skip_topic, const uint32_t subscriptionIdentifier);
MqttPacket(const PubResponse &pubAck);
MqttPacket(const Disconnect &disconnect);
MqttPacket(const Auth &auth);
Expand Down
13 changes: 7 additions & 6 deletions publishcopyfactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ PublishCopyFactory::PublishCopyFactory(Publish *publish) :

}

MqttPacket *PublishCopyFactory::getOptimumPacket(const uint8_t max_qos, const ProtocolVersion protocolVersion, uint16_t topic_alias, bool skip_topic)
MqttPacket *PublishCopyFactory::getOptimumPacket(
const uint8_t max_qos, const ProtocolVersion protocolVersion, uint16_t topic_alias, bool skip_topic, int subscriptionIdentifier)
{
const uint8_t actualQos = getEffectiveQos(max_qos);

Expand All @@ -39,10 +40,10 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const uint8_t max_qos, const Pr
// The incoming topic alias is not relevant after initial conversion and it should not propagate.
assert(packet->getPublishData().topicAlias == 0);

// When the packet contains an alias specific to the receiver, we don't cache it.
if (protocolVersion >= ProtocolVersion::Mqtt5 && topic_alias > 0)
// When the packet contains an data specific to the receiver, we don't cache it.
if ((protocolVersion >= ProtocolVersion::Mqtt5 && topic_alias > 0) || subscriptionIdentifier > 0)
{
this->oneShotPacket.emplace(protocolVersion, packet->getPublishData(), actualQos, topic_alias, skip_topic);
this->oneShotPacket.emplace(protocolVersion, packet->getPublishData(), actualQos, topic_alias, skip_topic, subscriptionIdentifier);
return &*this->oneShotPacket;
}

Expand All @@ -60,7 +61,7 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const uint8_t max_qos, const Pr

if (!cachedPack)
{
cachedPack.emplace(protocolVersion, packet->getPublishData(), actualQos, 0, false);
cachedPack.emplace(protocolVersion, packet->getPublishData(), actualQos, 0, false, 0);
}

return &*cachedPack;
Expand All @@ -72,7 +73,7 @@ MqttPacket *PublishCopyFactory::getOptimumPacket(const uint8_t max_qos, const Pr
// The incoming topic alias is not relevant after initial conversion and it should not propagate.
assert(publish->topicAlias == 0);

this->oneShotPacket.emplace(protocolVersion, *publish, actualQos, topic_alias, skip_topic);
this->oneShotPacket.emplace(protocolVersion, *publish, actualQos, topic_alias, skip_topic, subscriptionIdentifier);
return &*this->oneShotPacket;
}

Expand Down
2 changes: 1 addition & 1 deletion publishcopyfactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class PublishCopyFactory
PublishCopyFactory(const PublishCopyFactory &other) = delete;
PublishCopyFactory(PublishCopyFactory &&other) = delete;

MqttPacket *getOptimumPacket(const uint8_t max_qos, const ProtocolVersion protocolVersion, uint16_t topic_alias, bool skip_topic);
MqttPacket *getOptimumPacket(const uint8_t max_qos, const ProtocolVersion protocolVersion, uint16_t topic_alias, bool skip_topic, int subscriptionIdentifier);
uint8_t getEffectiveQos(uint8_t max_qos) const;
bool getEffectiveRetain(bool retainAsPublished) const;
const std::string &getTopic() const;
Expand Down
2 changes: 1 addition & 1 deletion retainedmessagesdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void RetainedMessagesDB::saveData(const std::vector<RetainedMessage> &messages)
logger->logf(LOG_DEBUG, "Saving retained message for topic '%s' QoS %d, age %d seconds.", rm.publish.topic.c_str(), rm.publish.qos, rm.publish.getAge());

Publish pcopy(rm.publish);
MqttPacket pack(ProtocolVersion::Mqtt5, pcopy);
MqttPacket pack(ProtocolVersion::Mqtt5, pcopy, 0); // TODO: subscription identifier

// Dummy, to please the parser on reading.
if (pcopy.qos > 0)
Expand Down
8 changes: 4 additions & 4 deletions session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ void Session::assignActiveConnection(const std::shared_ptr<Session> &thisSession
* @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets.
* @param count. Reference value is updated. It's for statistics.
*/
PacketDropReason Session::writePacket(PublishCopyFactory &copyFactory, const uint8_t max_qos, bool retainAsPublished)
PacketDropReason Session::writePacket(PublishCopyFactory &copyFactory, const uint8_t max_qos, bool retainAsPublished, const uint32_t subscriptionIdentifier)
{
/*
* We want to do as little as possible before the ACL check, because it's code that's called
Expand Down Expand Up @@ -180,7 +180,7 @@ PacketDropReason Session::writePacket(PublishCopyFactory &copyFactory, const uin
pack_id = getNextPacketId();

if (!destroyOnDisconnect)
qosPacketQueue.queuePublish(copyFactory, pack_id, effectiveQos, effectiveRetain);
qosPacketQueue.queuePublish(copyFactory, pack_id, effectiveQos, effectiveRetain); // TODO: here subscription identifier
}

PacketDropReason return_value = PacketDropReason::ClientOffline;
Expand All @@ -190,7 +190,7 @@ PacketDropReason Session::writePacket(PublishCopyFactory &copyFactory, const uin
if (!c->isRetainedAvailable())
effectiveRetain = false;

return_value = c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, pack_id, effectiveRetain);
return_value = c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, pack_id, effectiveRetain, subscriptionIdentifier);
}

return return_value;
Expand Down Expand Up @@ -289,7 +289,7 @@ void Session::sendAllPendingQosData()
{
PublishCopyFactory fac(&p.first);
const bool retain = !c->isRetainedAvailable() ? false : p.first.retain;
c->writeMqttPacketAndBlameThisClient(fac, p.first.qos, p.second, retain);
c->writeMqttPacketAndBlameThisClient(fac, p.first.qos, p.second, retain, 0); // TODO: subscription identifiers
}

for(uint16_t id : copiedQoS2Ids)
Expand Down
2 changes: 1 addition & 1 deletion session.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Session
void assignActiveConnection(const std::shared_ptr<Client> &client);
void assignActiveConnection(const std::shared_ptr<Session> &thisSession, const std::shared_ptr<Client> &client,
uint16_t clientReceiveMax, uint32_t sessionExpiryInterval, bool clean_start);
PacketDropReason writePacket(PublishCopyFactory &copyFactory, const uint8_t max_qos, bool retainAsPublished);
PacketDropReason writePacket(PublishCopyFactory &copyFactory, const uint8_t max_qos, bool retainAsPublished, const uint32_t subscriptionIdentifier);
bool clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds);
void sendAllPendingQosData();
bool hasActiveClient();
Expand Down
4 changes: 2 additions & 2 deletions sessionsandsubscriptionsdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector<std::shared_ptr<Sess

logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str());

MqttPacket pack(ProtocolVersion::Mqtt5, pub);
MqttPacket pack(ProtocolVersion::Mqtt5, pub, 0); // TODO: subscription identifier
pack.setPacketId(p.getPacketId());
const uint32_t packSize = pack.getSizeIncludingNonPresentHeader();
cirbuf.reset();
Expand Down Expand Up @@ -368,7 +368,7 @@ void SessionsAndSubscriptionsDB::saveData(const std::vector<std::shared_ptr<Sess

if (hasWillThatShouldSurviveRestart)
{
MqttPacket willpacket(ProtocolVersion::Mqtt5, *will);
MqttPacket willpacket(ProtocolVersion::Mqtt5, *will, 0); // TODO subscription identifier

// Dummy, to please the parser on reading.
if (will->qos > 0)
Expand Down
6 changes: 3 additions & 3 deletions subscriptionstore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ ReceivingSubscriber::ReceivingSubscriber(const std::weak_ptr<Session> &ses, uint
session(ses.lock()),
qos(qos),
retainAsPublished(retainAsPublished),
subscriptoinIdentifier(subscriptionIdentifier)
subscriptionIdentifier(subscriptionIdentifier)
{

}
Expand Down Expand Up @@ -672,7 +672,7 @@ void SubscriptionStore::queuePacketAtSubscribers(PublishCopyFactory &copyFactory

for(const ReceivingSubscriber &x : subscriberSessions)
{
x.session->writePacket(copyFactory, x.qos, x.retainAsPublished);
x.session->writePacket(copyFactory, x.qos, x.retainAsPublished, x.subscriptionIdentifier);
}
}

Expand Down Expand Up @@ -704,7 +704,7 @@ void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector<std::s
if (auth.aclCheck(publish, publish.payload) == AuthResult::success)
{
PublishCopyFactory copyFactory(&publish);
const PacketDropReason drop_reason = session->writePacket(copyFactory, max_qos, true);
const PacketDropReason drop_reason = session->writePacket(copyFactory, max_qos, true, 0); // TODO: subscription identifier. From args?

if (drop_reason == PacketDropReason::BufferFull || drop_reason == PacketDropReason::QoSTODOSomethingSomething)
{
Expand Down
2 changes: 1 addition & 1 deletion subscriptionstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct ReceivingSubscriber
const std::shared_ptr<Session> session;
const uint8_t qos;
const bool retainAsPublished;
const int subscriptoinIdentifier = 0;
const int subscriptionIdentifier = 0;

public:
ReceivingSubscriber(const std::weak_ptr<Session> &ses, uint8_t qos, bool retainAsPublished, const int subscriptionIdentifier);
Expand Down
3 changes: 3 additions & 0 deletions types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ std::optional<Mqtt5PropertyBuilder> Publish::getPropertyBuilder() const
if (topicAlias > 0)
non_optional(property_builder)->writeTopicAlias(topicAlias);

if (subscriptionIdentifier > 0)
non_optional(property_builder)->writeSubscriptionIdentifier(subscriptionIdentifier);

if (userProperties)
non_optional(property_builder)->writeUserProperties(*userProperties);

Expand Down
1 change: 1 addition & 0 deletions types.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ class Publish
std::optional<std::string> correlationData;
std::optional<std::string> responseTopic;
std::optional<std::string> contentType;
uint32_t subscriptionIdentifier = 0;
std::shared_ptr<std::vector<std::pair<std::string, std::string>>> userProperties;

Publish() = default;
Expand Down
5 changes: 5 additions & 0 deletions variablebyteint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ See LICENSE for license details.
#include <cstring>
#include <stdexcept>

VariableByteInt::VariableByteInt(uint32_t val)
{
*this = val;
}

void VariableByteInt::readIntoBuf(CirBuf &buf) const
{
assert(len > 0);
Expand Down
3 changes: 3 additions & 0 deletions variablebyteint.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class VariableByteInt
uint8_t len = 0;

public:
VariableByteInt(uint32_t val);
VariableByteInt() = default;

void readIntoBuf(CirBuf &buf) const;
VariableByteInt &operator=(uint32_t x);
uint8_t getLen() const;
Expand Down

0 comments on commit c4c7c77

Please sign in to comment.