Skip to content

Commit

Permalink
set with ifeq support
Browse files Browse the repository at this point in the history
Signed-off-by: Sarthak Aggarwal <sarthagg@amazon.com>
  • Loading branch information
sarthakaggarwal97 committed Nov 19, 2024
1 parent ee386c9 commit 6c216c1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
41 changes: 32 additions & 9 deletions src/t_string.c
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ static int checkStringLength(client *c, long long size, long long append) {
#define OBJ_EXAT (1 << 6) /* Set if timestamp in second is given */
#define OBJ_PXAT (1 << 7) /* Set if timestamp in ms is given */
#define OBJ_PERSIST (1 << 8) /* Set if we need to remove the ttl */
#define OBJ_SET_IFEQ (1 << 9) /* Set if we need compare and set */

/* Forward declaration */
static int getExpireMillisecondsOrReply(client *c, robj *expire, int flags, int unit, long long *milliseconds);
Expand All @@ -87,7 +88,8 @@ void setGenericCommand(client *c,
robj *expire,
int unit,
robj *ok_reply,
robj *abort_reply) {
robj *abort_reply,
robj *comparison) {
long long milliseconds = 0; /* initialized to avoid any harmness warning */
int found = 0;
int setkey_flags = 0;
Expand All @@ -102,6 +104,15 @@ void setGenericCommand(client *c,

found = (lookupKeyWrite(c->db, key) != NULL);

/* Handle the IFEQ conditional check */
if ((flags & OBJ_SET_IFEQ) && found) {
robj *current_value = lookupKeyRead(c->db, key);
if (current_value == NULL || compareStringObjects(current_value, comparison) != 0) {
addReply(c, abort_reply ? abort_reply : shared.null[c->resp]);
return;
}
}

if ((flags & OBJ_SET_NX && found) || (flags & OBJ_SET_XX && !found)) {
if (!(flags & OBJ_SET_GET)) {
addReply(c, abort_reply ? abort_reply : shared.null[c->resp]);
Expand Down Expand Up @@ -219,7 +230,7 @@ static int getExpireMillisecondsOrReply(client *c, robj *expire, int flags, int
* Input flags are updated upon parsing the arguments. Unit and expire are updated if there are any
* EX/EXAT/PX/PXAT arguments. Unit is updated to millisecond if PX/PXAT is set.
*/
int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj **expire, int command_type) {
int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj **expire, robj **compare_val, int command_type) {
int j = command_type == COMMAND_GET ? 2 : 3;
for (; j < c->argc; j++) {
char *opt = c->argv[j]->ptr;
Expand Down Expand Up @@ -295,7 +306,17 @@ int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj *
*unit = UNIT_MILLISECONDS;
*expire = next;
j++;
} else {
} else if ((opt[0] == 'i' || opt[0] == 'I') &&
(opt[1] == 'f' || opt[1] == 'F') &&
(opt[2] == 'e' || opt[2] == 'E') &&
(opt[3] == 'q' || opt[3] == 'Q') && opt[4] == '\0' &&
next && (command_type == COMMAND_SET))
{
*flags |= OBJ_SET_IFEQ;
*compare_val = next;
j++;
}
else {
addReplyErrorObject(c,shared.syntaxerr);
return C_ERR;
}
Expand All @@ -308,30 +329,31 @@ int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj *
* [EXAT <seconds-timestamp>][PXAT <milliseconds-timestamp>] */
void setCommand(client *c) {
robj *expire = NULL;
robj *comparison = NULL;
int unit = UNIT_SECONDS;
int flags = OBJ_NO_FLAGS;

if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, COMMAND_SET) != C_OK) {
if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, &comparison, COMMAND_SET) != C_OK) {
return;
}

c->argv[2] = tryObjectEncoding(c->argv[2]);
setGenericCommand(c, flags, c->argv[1], c->argv[2], expire, unit, NULL, NULL);
setGenericCommand(c, flags, c->argv[1], c->argv[2], expire, unit, NULL, NULL, comparison);
}

void setnxCommand(client *c) {
c->argv[2] = tryObjectEncoding(c->argv[2]);
setGenericCommand(c, OBJ_SET_NX, c->argv[1], c->argv[2], NULL, 0, shared.cone, shared.czero);
setGenericCommand(c, OBJ_SET_NX, c->argv[1], c->argv[2], NULL, 0, shared.cone, shared.czero, NULL);
}

void setexCommand(client *c) {
c->argv[3] = tryObjectEncoding(c->argv[3]);
setGenericCommand(c, OBJ_EX, c->argv[1], c->argv[3], c->argv[2], UNIT_SECONDS, NULL, NULL);
setGenericCommand(c, OBJ_EX, c->argv[1], c->argv[3], c->argv[2], UNIT_SECONDS, NULL, NULL, NULL);
}

void psetexCommand(client *c) {
c->argv[3] = tryObjectEncoding(c->argv[3]);
setGenericCommand(c, OBJ_PX, c->argv[1], c->argv[3], c->argv[2], UNIT_MILLISECONDS, NULL, NULL);
setGenericCommand(c, OBJ_PX, c->argv[1], c->argv[3], c->argv[2], UNIT_MILLISECONDS, NULL, NULL, NULL);
}

int getGenericCommand(client *c) {
Expand Down Expand Up @@ -374,10 +396,11 @@ void getCommand(client *c) {
*/
void getexCommand(client *c) {
robj *expire = NULL;
robj *comparison = NULL;
int unit = UNIT_SECONDS;
int flags = OBJ_NO_FLAGS;

if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, COMMAND_GET) != C_OK) {
if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, &comparison, COMMAND_GET) != C_OK) {
return;
}

Expand Down
13 changes: 13 additions & 0 deletions tests/unit/type/string.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,19 @@ if {[string match {*jemalloc*} [s mem_allocator]]} {
set err1
} {*WRONGTYPE*}

test "SET with IFEQ conditional" {
# Setting an initial value for the key
r set foo "initial_value"

# Trying to set the key only if the value is exactly "initial_value"
assert_equal OK [r set foo "new_value" ifeq "initial_value"]
assert_equal "new_value" [r get foo]

# Trying to set the key only if the value is NOT "initial_value"
assert_equal {} [r set foo "should_not_set" ifeq "wrong_value"]
assert_equal "new_value" [r get foo]
}

test {Extended SET EX option} {
r del foo
r set foo bar ex 10
Expand Down

0 comments on commit 6c216c1

Please sign in to comment.