Skip to content

Commit

Permalink
Add support for array of strings as metadata field.
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarlaird committed Apr 1, 2024
1 parent fa4dd6b commit 34a2eb5
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 33 deletions.
16 changes: 16 additions & 0 deletions sql/vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,22 @@ CREATE OPERATOR CLASS float_pinecone_ops
OPERATOR 5 > (float8, float8),
OPERATOR 6 != (float8, float8);

-- list of strings
CREATE OPERATOR CLASS list_of_strings_pinecone_ops
DEFAULT FOR TYPE text[] USING pinecone AS
OPERATOR 7 && (anyarray, anyarray), -- overlap
OPERATOR 2 @> (anyarray, anyarray);

-- int opclass for pinecone
CREATE OPERATOR CLASS int_pinecone_ops
DEFAULT FOR TYPE int4 USING pinecone AS
OPERATOR 1 < (int4, int4),
OPERATOR 2 <= (int4, int4),
OPERATOR 3 = (int4, int4),
OPERATOR 4 >= (int4, int4),
OPERATOR 5 > (int4, int4),
OPERATOR 6 != (int4, int4);

-- we want consistent naming
-- < 1
-- <= 2
Expand Down
6 changes: 4 additions & 2 deletions src/pinecone/pinecone.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "utils/guc.h"
#include <access/reloptions.h>

#include <float.h>


#if PG_VERSION_NUM < 150000
#define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x)
Expand Down Expand Up @@ -90,8 +92,8 @@ void no_costestimate(PlannerInfo *root, IndexPath *path, double loop_count,
{
// todo: consider running a health check on the remote index and return infinity if it is not healthy
if (list_length(path->indexorderbycols) == 0 || linitial_int(path->indexorderbycols) != 0) {
elog(DEBUG1, "Index must be ordered by the first column");
*indexTotalCost = 1000000;
elog(DEBUG1, "Pinecone index must be ordered by distance. Returning infinity.");
*indexTotalCost = DBL_MAX;
return;
}
};
Expand Down
5 changes: 5 additions & 0 deletions src/pinecone/pinecone.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
#define DEFAULT_SPEC "{}"
#define DEFAULT_HOST ""

// strategy numbers
#define PINECONE_STRATEGY_ARRAY_OVERLAP 7
#define PINECONE_STRATEGY_ARRAY_CONTAINS 2

// structs
typedef struct PineconeScanOpaqueData
{
Expand Down Expand Up @@ -209,6 +213,7 @@ cJSON* index_tuple_get_pinecone_vector(Relation index, IndexTuple itup);
cJSON* heap_tuple_get_pinecone_vector(Relation heap, HeapTuple htup);
char* pinecone_id_from_heap_tid(ItemPointerData heap_tid);
ItemPointerData pinecone_id_get_heap_tid(char *id);
cJSON* text_array_get_json(Datum value);
// read and write meta pages
PineconeStaticMetaPageData PineconeSnapshotStaticMeta(Relation index);
PineconeBufferMetaPageData PineconeSnapshotBufferMeta(Relation index);
Expand Down
23 changes: 14 additions & 9 deletions src/pinecone/pinecone_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void set_curl_options(CURL *hnd, const char *api_key, const char *url, const cha
// Declare the CURL handle as a global variable
CURL *hnd_t;

cJSON* generic_pinecone_request(const char *api_key, const char *url, const char *method, cJSON *body) {
cJSON* generic_pinecone_request(const char *api_key, const char *url, const char *method, cJSON *body, bool expect_json_response) {
// CURL *hnd = curl_easy_init();
ResponseData response_data = {"", NULL, NULL, 0, ""};
cJSON *response_json, *error;
Expand Down Expand Up @@ -105,6 +105,11 @@ cJSON* generic_pinecone_request(const char *api_key, const char *url, const char
elog(ERROR, "curl_easy_perform() failed: %s", curl_easy_strerror(ret));
}


// parse the response
if (!expect_json_response) {
return NULL;
}
response_json = cJSON_Parse(response_data.data);

if (response_json == NULL) {
Expand All @@ -124,19 +129,19 @@ cJSON* generic_pinecone_request(const char *api_key, const char *url, const char
*/
cJSON* describe_index(const char *api_key, const char *index_name) {
char url[100] = "https://api.pinecone.io/indexes/"; strcat(url, index_name);
return generic_pinecone_request(api_key, url, "GET", NULL);
return generic_pinecone_request(api_key, url, "GET", NULL, true);
}

cJSON* pinecone_get_index_stats(const char *api_key, const char *index_host) {
cJSON* resp;
char url[100] = "https://"; strcat(url, index_host); strcat(url, "/describe_index_stats");
resp = generic_pinecone_request(api_key, url, "GET", NULL);
resp = generic_pinecone_request(api_key, url, "GET", NULL, true);
return resp;
}

cJSON* list_indexes(const char *api_key) {
cJSON* response_json;
response_json = generic_pinecone_request(api_key, "https://api.pinecone.io/indexes", "GET", NULL);
response_json = generic_pinecone_request(api_key, "https://api.pinecone.io/indexes", "GET", NULL, true);
return cJSON_GetObjectItemCaseSensitive(response_json, "indexes");
}

Expand All @@ -145,19 +150,19 @@ cJSON* pinecone_delete_vectors(const char *api_key, const char *index_host, cJSO
char url[300];
sprintf(url, "https://%s/vectors/delete", index_host);
cJSON_AddItemToObject(request, "ids", ids);
return generic_pinecone_request(api_key, url, "POST", request);
return generic_pinecone_request(api_key, url, "POST", request, true);
}

cJSON* pinecone_delete_index(const char *api_key, const char *index_name) {
char url[100] = "https://api.pinecone.io/indexes/"; strcat(url, index_name);
return generic_pinecone_request(api_key, url, "DELETE", NULL);
return generic_pinecone_request(api_key, url, "DELETE", NULL, false);
}

// delete all vectors in an index
cJSON* pinecone_delete_all(const char *api_key, const char *index_host) {
char url[300];
sprintf(url, "https://%s/vectors/delete", index_host);
return generic_pinecone_request(api_key, url, "POST", cJSON_Parse("{\"deleteAll\": true}"));
return generic_pinecone_request(api_key, url, "POST", cJSON_Parse("{\"deleteAll\": true}"), true);
}

cJSON* pinecone_list_vectors(const char *api_key, const char *index_host, int limit, char* pagination_token) {
Expand All @@ -167,7 +172,7 @@ cJSON* pinecone_list_vectors(const char *api_key, const char *index_host, int li
} else {
sprintf(url, "https://%s/vectors/list?limit=%d", index_host, limit);
}
return cJSON_GetObjectItem(generic_pinecone_request(api_key, url, "GET", NULL), "vectors");
return cJSON_GetObjectItem(generic_pinecone_request(api_key, url, "GET", NULL, true), "vectors");
}

/* name, dimension, metric
Expand All @@ -181,7 +186,7 @@ cJSON* pinecone_create_index(const char *api_key, const char *index_name, const
cJSON_AddItemToObject(request, "dimension", cJSON_CreateNumber(dimension));
cJSON_AddItemToObject(request, "metric", cJSON_CreateString(metric));
cJSON_AddItemToObject(request, "spec", spec);
return generic_pinecone_request(api_key, "https://api.pinecone.io/indexes", "POST", request);
return generic_pinecone_request(api_key, "https://api.pinecone.io/indexes", "POST", request, true);
}

CURL* multi_hnd_for_query;
Expand Down
2 changes: 1 addition & 1 deletion src/pinecone/pinecone_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ typedef struct {
size_t write_callback(char *contents, size_t size, size_t nmemb, void *userdata);
struct curl_slist *create_common_headers(const char *api_key);
void set_curl_options(CURL *hnd, const char *api_key, const char *url, const char *method, ResponseData *response_data);
cJSON* generic_pinecone_request(const char *api_key, const char *url, const char *method, cJSON *body);
cJSON* generic_pinecone_request(const char *api_key, const char *url, const char *method, cJSON *body, bool expect_json_response);
cJSON* describe_index(const char *api_key, const char *index_name);
cJSON* pinecone_get_index_stats(const char *api_key, const char *index_host);
cJSON* list_indexes(const char *api_key);
Expand Down
69 changes: 49 additions & 20 deletions src/pinecone/pinecone_scan.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,33 +122,63 @@ IndexScanDesc pinecone_beginscan(Relation index, int nkeys, int norderbys)
cJSON* pinecone_build_filter(Relation index, ScanKey keys, int nkeys) {
cJSON *filter = cJSON_CreateObject();
cJSON *and_list = cJSON_CreateArray();
const char* pinecone_filter_operators[] = {"$lt", "$lte", "$eq", "$gte", "$gt", "$ne"};
const char* pinecone_filter_operators[] = {"$lt", "$lte", "$eq", "$gte", "$gt", "$ne", "$in"};
for (int i = 0; i < nkeys; i++)
{
cJSON *key_filter = cJSON_CreateObject();
cJSON *condition = cJSON_CreateObject();
cJSON *condition_value = NULL;
FormData_pg_attribute* td = TupleDescAttr(index->rd_att, keys[i].sk_attno - 1);

switch (td->atttypid)
{
case BOOLOID:
condition_value = cJSON_CreateBool(DatumGetBool(keys[i].sk_argument));
break;
case FLOAT8OID:
condition_value = cJSON_CreateNumber(DatumGetFloat8(keys[i].sk_argument));
break;
case TEXTOID:
condition_value = cJSON_CreateString(text_to_cstring(DatumGetTextP(keys[i].sk_argument)));
break;
default:
continue; // skip unsupported types
if (td->atttypid == TEXTARRAYOID && keys[i].sk_strategy == PINECONE_STRATEGY_ARRAY_CONTAINS) {
// contains (list_of_strings @> ARRAY[tag1, tag2])
// $and: [ {list_of_strings: {$in: [tag1]}}, {list_of_strings: {$in: [tag2]}} ]
cJSON* tags = text_array_get_json(keys[i].sk_argument);
cJSON* tag;
cJSON_ArrayForEach(tag, tags) {
cJSON* condition_contains_tag = cJSON_CreateObject(); // list_of_strings: {$in: [tag1]}
cJSON* predicate_contains_tag = cJSON_CreateObject(); // {$in: [tag1]}
cJSON* single_tag_list = cJSON_CreateArray(); // [tag1]
cJSON_AddItemToArray(single_tag_list, cJSON_Duplicate(tag, true)); // [tag1]
cJSON_AddItemToObject(predicate_contains_tag, "$in", single_tag_list); // {$in: [tag1]}
cJSON_AddItemToObject(condition_contains_tag, td->attname.data, predicate_contains_tag); // list_of_strings: {$in: [tag1]}
cJSON_AddItemToArray(and_list, condition_contains_tag);
}
// cJSON_Delete(tags);
} else {
switch (td->atttypid)
{
case BOOLOID:
condition_value = cJSON_CreateBool(DatumGetBool(keys[i].sk_argument));
break;
case FLOAT8OID:
condition_value = cJSON_CreateNumber(DatumGetFloat8(keys[i].sk_argument));
break;
case TEXTOID:
condition_value = cJSON_CreateString(text_to_cstring(DatumGetTextP(keys[i].sk_argument)));
break;
case TEXTARRAYOID:
// overlap
if (keys[i].sk_strategy != PINECONE_STRATEGY_ARRAY_OVERLAP) {
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Unsupported operator for text[] datatype. Must be && (overlap)")));
}
condition_value = text_array_get_json(keys[i].sk_argument);
break;
// contains (list_of_strings @> ARRAY[tag1, tag2])
// $and: [ {list_of_strings: {$in: [tag1]}}, {list_of_strings: {$in: [tag2]}} ]
default:
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Unsupported datatype for pinecone index scan. Must be one of bool, float8, text, text[]")));
break;
}
// this only works if all datatypes use the same strategy naming convention. todo: document this
cJSON_AddItemToObject(condition, pinecone_filter_operators[keys[i].sk_strategy - 1], condition_value);
cJSON_AddItemToObject(key_filter, td->attname.data, condition);
cJSON_AddItemToArray(and_list, key_filter);
}

// this only works if all datatypes use the same strategy naming convention. todo: document this
cJSON_AddItemToObject(condition, pinecone_filter_operators[keys[i].sk_strategy - 1], condition_value);
cJSON_AddItemToObject(key_filter, td->attname.data, condition);
cJSON_AddItemToArray(and_list, key_filter);
}
cJSON_AddItemToObject(filter, "$and", and_list);
return filter;
Expand Down Expand Up @@ -185,7 +215,6 @@ void pinecone_rescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderb

// build the filter
filter = pinecone_build_filter(scan->indexRelation, keys, nkeys);
elog(DEBUG1, "filter: %s", cJSON_Print(filter));

// get the query vector
query_datum = orderbys[0].sk_argument;
Expand Down
41 changes: 40 additions & 1 deletion src/pinecone/pinecone_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "access/generic_xlog.h"
#include "access/relscan.h"
#include "utils/builtins.h"
#include "utils/lsyscache.h"

cJSON* tuple_get_pinecone_vector(TupleDesc tup_desc, Datum *values, bool *isnull, char *vector_id)
{
Expand All @@ -26,13 +27,22 @@ cJSON* tuple_get_pinecone_vector(TupleDesc tup_desc, Datum *values, bool *isnull
case FLOAT8OID:
cJSON_AddItemToObject(metadata, NameStr(td->attname), cJSON_CreateNumber(DatumGetFloat8(values[i])));
break;
case INT4OID:
cJSON_AddItemToObject(metadata, NameStr(td->attname), cJSON_CreateNumber(DatumGetInt32(values[i])));
break;
case TEXTOID:
cJSON_AddItemToObject(metadata, NameStr(td->attname), cJSON_CreateString(text_to_cstring(DatumGetTextP(values[i]))));
break;
case TEXTARRAYOID:
{
cJSON* json_array = text_array_get_json(values[i]);
cJSON_AddItemToObject(metadata, NameStr(td->attname), json_array);
}
break;
default:
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid column type when decoding tuple."),
errhint("Pinecone index only supports boolean, float8 and text columns")));
errhint("Pinecone index only supports boolean, float8, text, and textarray columns")));
}
}
// add to vector object
Expand Down Expand Up @@ -241,4 +251,33 @@ hash_tid(ItemPointerData tid, int seed)
x.tid = tid;

return murmurhash64(x.i + seed);
}

/* text_array_get_json */
cJSON* text_array_get_json(Datum value) {
ArrayType *array = DatumGetArrayTypeP(value);
int nelems = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
Datum* elems;
bool* nulls;
int16 elmlen;
bool elmbyval;
char elmalign;
Oid elmtype = ARR_ELEMTYPE(array);
cJSON *json_array = cJSON_CreateArray();

// get array element type info
get_typlenbyvalalign(elmtype, &elmlen, &elmbyval, &elmalign);

// deconstruct array
deconstruct_array(array, elmtype, elmlen, elmbyval, elmalign, &elems, &nulls, &nelems);

// copy array elements to json array
for (int j = 0; j < nelems; j++) {
if (!nulls[j]) {
Datum elem = elems[j];
char* cstr = TextDatumGetCString(elem);
cJSON_AddItemToArray(json_array, cJSON_CreateString(cstr));
}
}
return json_array;
}

0 comments on commit 34a2eb5

Please sign in to comment.