Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
Signed-off-by: Roshan Khatri <rvkhatri@amazon.com>
  • Loading branch information
roshkhatri committed May 15, 2024
1 parent 93f8a19 commit fe09e22
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void clusterUpdateMyselfHostname(void);
void clusterUpdateMyselfAnnouncedPorts(void);
void clusterUpdateMyselfHumanNodename(void);

void clusterPropagatePublish(robj *channel, robj *message, int sharded);
void clusterPropagatePublish(robj *channel, robj **message, int count, int sharded);

unsigned long getClusterConnectionsCount(void);
int isClusterHealthy(void);
Expand Down
114 changes: 110 additions & 4 deletions src/cluster_legacy.c
Original file line number Diff line number Diff line change
Expand Up @@ -2763,7 +2763,7 @@ int clusterProcessPacket(clusterLink *link) {
} else if (type == CLUSTERMSG_TYPE_FAIL) {
explen = sizeof(clusterMsg)-sizeof(union clusterMsgData);
explen += sizeof(clusterMsgDataFail);
} else if (type == CLUSTERMSG_TYPE_PUBLISH || type == CLUSTERMSG_TYPE_PUBLISHSHARD) {
} else if (type == CLUSTERMSG_TYPE_PUBLISH || type == CLUSTERMSG_TYPE_PUBLISHSHARD || type == CLUSTERMSG_TYPE_MPUBLISH ) {
explen = sizeof(clusterMsg)-sizeof(union clusterMsgData);
explen += sizeof(clusterMsgDataPublish) -
8 +
Expand Down Expand Up @@ -3197,6 +3197,60 @@ int clusterProcessPacket(clusterLink *link) {
decrRefCount(channel);
decrRefCount(message);
}
} else if (type == CLUSTERMSG_TYPE_MPUBLISH) {
if (!sender) return 1; /* We don't know that node. */

const uint8_t *src, *end;
robj *channel, *message;
uint32_t channel_len, message_len, len;
unsigned i, msg_count;

/* Don't bother creating useless objects if there are no
* Pub/Sub subscribers. */
if ((type == CLUSTERMSG_TYPE_MPUBLISH
&& serverPubsubSubscriptionCount() > 0)) {
channel_len = ntohl(hdr->data.publish.msg.channel_len);
message_len = ntohl(hdr->data.publish.msg.message_len);
/* Count messages */
src = hdr->data.publish.msg.bulk_data + channel_len;
end = src + message_len;
msg_count = 0;
while (src + 4 <= end) {
memcpy(&len, src, sizeof(len));
len = ntohl(len);
src += 4;
if (src + len > end) {
serverLog(LL_WARNING,
"Received %s packet with malformed messages",
clusterGetMessageTypeString(type));
return 1;
}
src += len;
msg_count++;
}
serverAssert(src <= end);
if (src != end) {
serverLog(LL_WARNING,
"Received %s packet with malformed messages (short)",
clusterGetMessageTypeString(type));
return 1;
}
channel = createStringObject(
(char*)hdr->data.publish.msg.bulk_data,channel_len);
/* Parse them out */
src = hdr->data.publish.msg.bulk_data + channel_len;
for (i = 0; src < end; ++i) {
memcpy(&len, src, sizeof(len));
len = ntohl(len);
src += 4;
message = createStringObject((char *) src, len);
src += len;
pubsubPublishMessage(channel, message, 0);
decrRefCount(message);
zfree(message);
}
decrRefCount(channel);
}
} else if (type == CLUSTERMSG_TYPE_FAILOVER_AUTH_REQUEST) {
if (!sender) return 1; /* We don't know that node. */
clusterSendFailoverAuthIfNeeded(sender,hdr);
Expand Down Expand Up @@ -3799,6 +3853,51 @@ clusterMsgSendBlock *clusterCreatePublishMsgBlock(robj *channel, robj *message,
return msgblock;
}

clusterMsgSendBlock *clusterCreateMPublishMsgBlock(robj *channel, robj **message, int count, uint16_t type) {

uint32_t channel_len, message_len;
uint32_t len;
size_t messages_aggregated_len;
unsigned char *end;
int i;

channel = getDecodedObject(channel);
channel_len = sdslen(channel->ptr);
size_t msglen = sizeof(clusterMsg)-sizeof(union clusterMsgData);

messages_aggregated_len = 0;
for (i = 0; i < count; i++) {
message[i] = getDecodedObject(message[i]);
message_len = sdslen(message[i]->ptr);
messages_aggregated_len += 4 + message_len;
}

msglen += sizeof(clusterMsgDataPublish) - 8 + messages_aggregated_len;
clusterMsgSendBlock *msgblock = createClusterMsgSendBlock(type, msglen);

clusterMsg *hdr = &msgblock->msg;
hdr->data.publish.msg.channel_len = htonl(channel_len);
hdr->data.publish.msg.message_len = htonl(messages_aggregated_len);
memcpy(hdr->data.publish.msg.bulk_data,channel->ptr,sdslen(channel->ptr));
end = hdr->data.publish.msg.bulk_data+channel_len;

for (i = 0; i < count; i++) {
message_len = sdslen(message[i]->ptr);
len = htonl(message_len);
memcpy(end, &len, sizeof(len));
end += sizeof(len);
memcpy(end,message[i]->ptr,message_len);
end += message_len;
}

decrRefCount(channel);
for (i = 0; i < count; i++) {
decrRefCount(message[i]);
}

return msgblock;
}

/* Send a FAIL message to all the nodes we are able to contact.
* The FAIL message is sent when we detect that a node is failing
* (CLUSTER_NODE_PFAIL) and we also receive a gossip confirmation of this:
Expand Down Expand Up @@ -3891,22 +3990,28 @@ int clusterSendModuleMessageToTarget(const char *target, uint64_t module_id, uin
* Otherwise:
* Publish this message across the slot (primary/replica).
* -------------------------------------------------------------------------- */
void clusterPropagatePublish(robj *channel, robj *message, int sharded) {
void clusterPropagatePublish(robj *channel, robj **message, int count, int sharded) {
clusterMsgSendBlock *msgblock;

if (!sharded) {
msgblock = clusterCreatePublishMsgBlock(channel, message, CLUSTERMSG_TYPE_PUBLISH);
if (count == 1) {
msgblock = clusterCreatePublishMsgBlock(channel, message[0], CLUSTERMSG_TYPE_PUBLISH);
}
else {
msgblock = clusterCreateMPublishMsgBlock(channel, message, count, CLUSTERMSG_TYPE_MPUBLISH);
}
clusterBroadcastMessage(msgblock);
clusterMsgSendBlockDecrRefCount(msgblock);
return;
}

serverAssert(count == 1);
listIter li;
listNode *ln;
list *nodes_for_slot = clusterGetNodesInMyShard(server.cluster->myself);
serverAssert(nodes_for_slot != NULL);
listRewind(nodes_for_slot, &li);
msgblock = clusterCreatePublishMsgBlock(channel, message, CLUSTERMSG_TYPE_PUBLISHSHARD);
msgblock = clusterCreatePublishMsgBlock(channel, message[0], CLUSTERMSG_TYPE_PUBLISHSHARD);
while((ln = listNext(&li))) {
clusterNode *node = listNodeValue(ln);
if (node->flags & (CLUSTER_NODE_MYSELF|CLUSTER_NODE_HANDSHAKE))
Expand Down Expand Up @@ -5560,6 +5665,7 @@ const char *clusterGetMessageTypeString(int type) {
case CLUSTERMSG_TYPE_MEET: return "meet";
case CLUSTERMSG_TYPE_FAIL: return "fail";
case CLUSTERMSG_TYPE_PUBLISH: return "publish";
case CLUSTERMSG_TYPE_MPUBLISH: return "mpublish";
case CLUSTERMSG_TYPE_PUBLISHSHARD: return "publishshard";
case CLUSTERMSG_TYPE_FAILOVER_AUTH_REQUEST: return "auth-req";
case CLUSTERMSG_TYPE_FAILOVER_AUTH_ACK: return "auth-ack";
Expand Down
3 changes: 2 additions & 1 deletion src/cluster_legacy.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ typedef struct clusterNodeFailReport {
#define CLUSTERMSG_TYPE_MFSTART 8 /* Pause clients for manual failover */
#define CLUSTERMSG_TYPE_MODULE 9 /* Module cluster API message. */
#define CLUSTERMSG_TYPE_PUBLISHSHARD 10 /* Pub/Sub Publish shard propagation */
#define CLUSTERMSG_TYPE_COUNT 11 /* Total number of message types. */
#define CLUSTERMSG_TYPE_MPUBLISH 11 /* Pub/Sub Multiple Publish propagation. */
#define CLUSTERMSG_TYPE_COUNT 12 /* Total number of message types. */

/* Initially we don't know our "name", but we'll find it once we connect
* to the first node, using the getsockname() function. Then we'll use this
Expand Down
24 changes: 24 additions & 0 deletions src/commands.def
Original file line number Diff line number Diff line change
Expand Up @@ -4430,6 +4430,29 @@ struct COMMAND_ARG RPUSHX_Args[] = {
{MAKE_ARG("element",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_MULTIPLE,0,NULL)},
};

/********** MPUBLISH ********************/

#ifndef SKIP_CMD_HISTORY_TABLE
/* MPUBLISH history */
#define MPUBLISH_History NULL
#endif

#ifndef SKIP_CMD_TIPS_TABLE
/* MPUBLISH tips */
#define MPUBLISH_Tips NULL
#endif

#ifndef SKIP_CMD_KEY_SPECS_TABLE
/* MPUBLISH key specs */
#define MPUBLISH_Keyspecs NULL
#endif

/* MPUBLISH argument table */
struct COMMAND_ARG MPUBLISH_Args[] = {
{MAKE_ARG("channel",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE,0,NULL)},
{MAKE_ARG("message",ARG_TYPE_STRING,-1,NULL,NULL,NULL,CMD_ARG_NONE,0,NULL)},
};

/********** PSUBSCRIBE ********************/

#ifndef SKIP_CMD_HISTORY_TABLE
Expand Down Expand Up @@ -10764,6 +10787,7 @@ struct COMMAND_STRUCT serverCommandTable[] = {
{MAKE_CMD("rpush","Appends one or more elements to a list. Creates the key if it doesn't exist.","O(1) for each element added, so O(N) to add N elements when the command is called with multiple arguments.","1.0.0",CMD_DOC_NONE,NULL,NULL,"list",COMMAND_GROUP_LIST,RPUSH_History,1,RPUSH_Tips,0,rpushCommand,-3,CMD_WRITE|CMD_DENYOOM|CMD_FAST,ACL_CATEGORY_LIST,RPUSH_Keyspecs,1,NULL,2),.args=RPUSH_Args},
{MAKE_CMD("rpushx","Appends an element to a list only when the list exists.","O(1) for each element added, so O(N) to add N elements when the command is called with multiple arguments.","2.2.0",CMD_DOC_NONE,NULL,NULL,"list",COMMAND_GROUP_LIST,RPUSHX_History,1,RPUSHX_Tips,0,rpushxCommand,-3,CMD_WRITE|CMD_DENYOOM|CMD_FAST,ACL_CATEGORY_LIST,RPUSHX_Keyspecs,1,NULL,2),.args=RPUSHX_Args},
/* pubsub */
{MAKE_CMD("mpublish","Posts multiple messages to a channel.","O(N+M+P) where N is the number of clients subscribed to the receiving channel, M is the total number of subscribed patterns (by any client) and P is the number of messages to be published.","2.0.0",CMD_DOC_NONE,NULL,NULL,"pubsub",COMMAND_GROUP_PUBSUB,MPUBLISH_History,0,MPUBLISH_Tips,0,mpublishCommand,-3,CMD_PUBSUB|CMD_LOADING|CMD_STALE|CMD_FAST|CMD_MAY_REPLICATE|CMD_SENTINEL,0,MPUBLISH_Keyspecs,0,NULL,2),.args=MPUBLISH_Args},
{MAKE_CMD("psubscribe","Listens for messages published to channels that match one or more patterns.","O(N) where N is the number of patterns to subscribe to.","2.0.0",CMD_DOC_NONE,NULL,NULL,"pubsub",COMMAND_GROUP_PUBSUB,PSUBSCRIBE_History,0,PSUBSCRIBE_Tips,0,psubscribeCommand,-2,CMD_PUBSUB|CMD_NOSCRIPT|CMD_LOADING|CMD_STALE|CMD_SENTINEL,0,PSUBSCRIBE_Keyspecs,0,NULL,1),.args=PSUBSCRIBE_Args},
{MAKE_CMD("publish","Posts a message to a channel.","O(N+M) where N is the number of clients subscribed to the receiving channel and M is the total number of subscribed patterns (by any client).","2.0.0",CMD_DOC_NONE,NULL,NULL,"pubsub",COMMAND_GROUP_PUBSUB,PUBLISH_History,0,PUBLISH_Tips,0,publishCommand,3,CMD_PUBSUB|CMD_LOADING|CMD_STALE|CMD_FAST|CMD_MAY_REPLICATE|CMD_SENTINEL,0,PUBLISH_Keyspecs,0,NULL,2),.args=PUBLISH_Args},
{MAKE_CMD("pubsub","A container for Pub/Sub commands.","Depends on subcommand.","2.8.0",CMD_DOC_NONE,NULL,NULL,"pubsub",COMMAND_GROUP_PUBSUB,PUBSUB_History,0,PUBSUB_Tips,0,NULL,-2,0,0,PUBSUB_Keyspecs,0,NULL,0),.subcommands=PUBSUB_Subcommands},
Expand Down
34 changes: 34 additions & 0 deletions src/commands/mpublish.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"MPUBLISH": {
"summary": "Posts multiple messages to a channel.",
"complexity": "O(N+M+P) where N is the number of clients subscribed to the receiving channel, M is the total number of subscribed patterns (by any client) and P is the number of messages to be published.",
"group": "pubsub",
"since": "2.0.0",
"arity": 3,
"function": "mpublishCommand",
"command_flags": [
"PUBSUB",
"LOADING",
"STALE",
"FAST",
"MAY_REPLICATE",
"SENTINEL"
],
"arguments": [
{
"name": "channel",
"type": "string"
},
{
"name": "message",
"type": "string",
"multiple": true
}
],
"reply_schema": {
"description": "The number of clients that received the message. Note that in a Cluster, only clients that are connected to the same node as the publishing client are included in the count.",
"type": "integer",
"minimum": 0
}
}
}
1 change: 1 addition & 0 deletions src/db.c
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,7 @@ ChannelSpecs commands_with_channels[] = {
{psubscribeCommand, CMD_CHANNEL_PATTERN | CMD_CHANNEL_SUBSCRIBE, 1, -1},
{punsubscribeCommand, CMD_CHANNEL_PATTERN | CMD_CHANNEL_UNSUBSCRIBE, 1, -1},
{publishCommand, CMD_CHANNEL_PUBLISH, 1, 1},
{mpublishCommand, CMD_CHANNEL_PUBLISH, 1, 1},
{spublishCommand, CMD_CHANNEL_PUBLISH, 1, 1},
{NULL,0} /* Terminator. */
};
Expand Down
47 changes: 38 additions & 9 deletions src/pubsub.c
Original file line number Diff line number Diff line change
Expand Up @@ -471,11 +471,12 @@ int pubsubUnsubscribeAllPatterns(client *c, int notify) {
/*
* Publish a message to all the subscribers.
*/
int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) {
int pubsubPublishMessageInternal(robj *channel, robj **message, int count, pubsubtype type) {
int receivers = 0;
dictEntry *de;
dictIterator *di;
unsigned int slot = 0;
int i;

/* Send to clients listening for that channel */
if (server.cluster_enabled && type.shard) {
Expand All @@ -488,8 +489,10 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type)
dictIterator *iter = dictGetIterator(clients);
while ((entry = dictNext(iter)) != NULL) {
client *c = dictGetKey(entry);
addReplyPubsubMessage(c,channel,message,*type.messageBulk);
updateClientMemUsageAndBucket(c);
for (i = 0; i < count; i++) {
addReplyPubsubMessage(c,channel,message[i],*type.messageBulk);
updateClientMemUsageAndBucket(c);
}
receivers++;
}
dictReleaseIterator(iter);
Expand All @@ -516,8 +519,10 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type)
dictIterator *iter = dictGetIterator(clients);
while ((entry = dictNext(iter)) != NULL) {
client *c = dictGetKey(entry);
addReplyPubsubPatMessage(c,pattern,channel,message);
updateClientMemUsageAndBucket(c);
for (i = 0; i < count; i++) {
addReplyPubsubPatMessage(c,pattern,channel,message[i]);
updateClientMemUsageAndBucket(c);
}
receivers++;
}
dictReleaseIterator(iter);
Expand All @@ -528,9 +533,14 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type)
return receivers;
}

/* Publish all message to all the subscribers. */
int pubsubPublishMessages(robj *channel, robj **message, int count, int sharded) {
return pubsubPublishMessageInternal(channel, message, count, sharded? pubSubShardType : pubSubType);
}

/* Publish a message to all the subscribers. */
int pubsubPublishMessage(robj *channel, robj *message, int sharded) {
return pubsubPublishMessageInternal(channel, message, sharded? pubSubShardType : pubSubType);
return pubsubPublishMessages(channel, &message, 1, sharded);
}

/*-----------------------------------------------------------------------------
Expand Down Expand Up @@ -608,13 +618,17 @@ void punsubscribeCommand(client *c) {

/* This function wraps pubsubPublishMessage and also propagates the message to cluster.
* Used by the commands PUBLISH/SPUBLISH and their respective module APIs.*/
int pubsubPublishMessageAndPropagateToCluster(robj *channel, robj *message, int sharded) {
int receivers = pubsubPublishMessage(channel, message, sharded);
int pubsubPublishMessagesAndPropagateToCluster(robj *channel, robj **message, int count, int sharded) {
int receivers = pubsubPublishMessages(channel, message, count, sharded);
if (server.cluster_enabled)
clusterPropagatePublish(channel, message, sharded);
clusterPropagatePublish(channel, message, count, sharded);
return receivers;
}

int pubsubPublishMessageAndPropagateToCluster(robj *channel, robj *message, int sharded) {
return pubsubPublishMessagesAndPropagateToCluster(channel, &message, 1, sharded);
}

/* PUBLISH <channel> <message> */
void publishCommand(client *c) {
if (server.sentinel_mode) {
Expand All @@ -628,6 +642,21 @@ void publishCommand(client *c) {
addReplyLongLong(c,receivers);
}

/* MPUBLISH <channel> [message ...] */
void mpublishCommand(client *c) {
// serverAssert(0);
if (server.sentinel_mode) {
sentinelPublishCommand(c);
return;
}

int receivers = pubsubPublishMessagesAndPropagateToCluster(c->argv[1], &c->argv[2], c->argc-2, 0);
if (!server.cluster_enabled)
forceCommandPropagation(c,PROPAGATE_REPL);
addReplyLongLong(c,receivers);
}


/* PUBSUB command for Pub/Sub introspection. */
void pubsubCommand(client *c) {
if (c->argc == 2 && !strcasecmp(c->argv[1]->ptr,"help")) {
Expand Down
1 change: 1 addition & 0 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -3658,6 +3658,7 @@ void unsubscribeCommand(client *c);
void psubscribeCommand(client *c);
void punsubscribeCommand(client *c);
void publishCommand(client *c);
void mpublishCommand(client *c);
void pubsubCommand(client *c);
void spublishCommand(client *c);
void ssubscribeCommand(client *c);
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/pubsub.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ start_server {tags {"pubsub network"}} {
$rd1 close
}

test "MPUBLISH basics" {
set rd1 [valkey_deferring_client]
assert_equal 1 [r mpublish chan1 helloword valkey]
assert_equal {message chan1 helloword} [$rd1 read]
assert_equal {message chan1 valkey} [$rd1 read]
}

test "PUBLISH/SUBSCRIBE with two clients" {
set rd1 [valkey_deferring_client]
set rd2 [valkey_deferring_client]
Expand Down

0 comments on commit fe09e22

Please sign in to comment.