diff --git a/src/t_string.c b/src/t_string.c index 1c90eabf3e..55a565f10b 100644 --- a/src/t_string.c +++ b/src/t_string.c @@ -76,6 +76,7 @@ static int checkStringLength(client *c, long long size, long long append) { #define OBJ_EXAT (1 << 6) /* Set if timestamp in second is given */ #define OBJ_PXAT (1 << 7) /* Set if timestamp in ms is given */ #define OBJ_PERSIST (1 << 8) /* Set if we need to remove the ttl */ +#define OBJ_SET_IFEQ (1 << 9) /* Set if we need compare and set */ /* Forward declaration */ static int getExpireMillisecondsOrReply(client *c, robj *expire, int flags, int unit, long long *milliseconds); @@ -87,7 +88,8 @@ void setGenericCommand(client *c, robj *expire, int unit, robj *ok_reply, - robj *abort_reply) { + robj *abort_reply, + robj *comparison) { long long milliseconds = 0; /* initialized to avoid any harmness warning */ int found = 0; int setkey_flags = 0; @@ -102,6 +104,15 @@ void setGenericCommand(client *c, found = (lookupKeyWrite(c->db, key) != NULL); + /* Handle the IFEQ conditional check */ + if ((flags & OBJ_SET_IFEQ) && found) { + robj *current_value = lookupKeyRead(c->db, key); + if (current_value == NULL || compareStringObjects(current_value, comparison) != 0) { + addReply(c, abort_reply ? abort_reply : shared.null[c->resp]); + return; + } + } + if ((flags & OBJ_SET_NX && found) || (flags & OBJ_SET_XX && !found)) { if (!(flags & OBJ_SET_GET)) { addReply(c, abort_reply ? abort_reply : shared.null[c->resp]); @@ -208,7 +219,7 @@ static int getExpireMillisecondsOrReply(client *c, robj *expire, int flags, int * string arguments used in SET and GET command. * * Get specific commands - PERSIST/DEL - * Set specific commands - XX/NX/GET + * Set specific commands - XX/NX/GET/IFEQ * Common commands - EX/EXAT/PX/PXAT/KEEPTTL * * Function takes pointers to client, flags, unit, pointer to pointer of expire obj if needed @@ -219,7 +230,7 @@ static int getExpireMillisecondsOrReply(client *c, robj *expire, int flags, int * Input flags are updated upon parsing the arguments. Unit and expire are updated if there are any * EX/EXAT/PX/PXAT arguments. Unit is updated to millisecond if PX/PXAT is set. */ -int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj **expire, int command_type) { +int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj **expire, robj **compare_val, int command_type) { int j = command_type == COMMAND_GET ? 2 : 3; for (; j < c->argc; j++) { char *opt = c->argv[j]->ptr; @@ -295,7 +306,17 @@ int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj * *unit = UNIT_MILLISECONDS; *expire = next; j++; - } else { + } else if ((opt[0] == 'i' || opt[0] == 'I') && + (opt[1] == 'f' || opt[1] == 'F') && + (opt[2] == 'e' || opt[2] == 'E') && + (opt[3] == 'q' || opt[3] == 'Q') && opt[4] == '\0' && + next && (command_type == COMMAND_SET)) + { + *flags |= OBJ_SET_IFEQ; + *compare_val = next; + j++; + } + else { addReplyErrorObject(c,shared.syntaxerr); return C_ERR; } @@ -308,30 +329,31 @@ int parseExtendedStringArgumentsOrReply(client *c, int *flags, int *unit, robj * * [EXAT ][PXAT ] */ void setCommand(client *c) { robj *expire = NULL; + robj *comparison = NULL; int unit = UNIT_SECONDS; int flags = OBJ_NO_FLAGS; - if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, COMMAND_SET) != C_OK) { + if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, &comparison, COMMAND_SET) != C_OK) { return; } c->argv[2] = tryObjectEncoding(c->argv[2]); - setGenericCommand(c, flags, c->argv[1], c->argv[2], expire, unit, NULL, NULL); + setGenericCommand(c, flags, c->argv[1], c->argv[2], expire, unit, NULL, NULL, comparison); } void setnxCommand(client *c) { c->argv[2] = tryObjectEncoding(c->argv[2]); - setGenericCommand(c, OBJ_SET_NX, c->argv[1], c->argv[2], NULL, 0, shared.cone, shared.czero); + setGenericCommand(c, OBJ_SET_NX, c->argv[1], c->argv[2], NULL, 0, shared.cone, shared.czero, NULL); } void setexCommand(client *c) { c->argv[3] = tryObjectEncoding(c->argv[3]); - setGenericCommand(c, OBJ_EX, c->argv[1], c->argv[3], c->argv[2], UNIT_SECONDS, NULL, NULL); + setGenericCommand(c, OBJ_EX, c->argv[1], c->argv[3], c->argv[2], UNIT_SECONDS, NULL, NULL, NULL); } void psetexCommand(client *c) { c->argv[3] = tryObjectEncoding(c->argv[3]); - setGenericCommand(c, OBJ_PX, c->argv[1], c->argv[3], c->argv[2], UNIT_MILLISECONDS, NULL, NULL); + setGenericCommand(c, OBJ_PX, c->argv[1], c->argv[3], c->argv[2], UNIT_MILLISECONDS, NULL, NULL, NULL); } int getGenericCommand(client *c) { @@ -374,10 +396,11 @@ void getCommand(client *c) { */ void getexCommand(client *c) { robj *expire = NULL; + robj *comparison = NULL; int unit = UNIT_SECONDS; int flags = OBJ_NO_FLAGS; - if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, COMMAND_GET) != C_OK) { + if (parseExtendedStringArgumentsOrReply(c, &flags, &unit, &expire, &comparison, COMMAND_GET) != C_OK) { return; } diff --git a/tests/unit/type/string.tcl b/tests/unit/type/string.tcl index d7969b5b3e..31e88674d7 100644 --- a/tests/unit/type/string.tcl +++ b/tests/unit/type/string.tcl @@ -582,6 +582,19 @@ if {[string match {*jemalloc*} [s mem_allocator]]} { set err1 } {*WRONGTYPE*} + test "SET with IFEQ conditional" { + # Setting an initial value for the key + r set foo "initial_value" + + # Trying to set the key only if the value is exactly "initial_value" + assert_equal OK [r set foo "new_value" ifeq "initial_value"] + assert_equal "new_value" [r get foo] + + # Trying to set the key only if the value is NOT "initial_value" + assert_equal {} [r set foo "should_not_set" ifeq "wrong_value"] + assert_equal "new_value" [r get foo] + } + test {Extended SET EX option} { r del foo r set foo bar ex 10