Skip to content

Commit

Permalink
[feat][sdk] Support Vector add with auto incre id
Browse files Browse the repository at this point in the history
  • Loading branch information
wchuande authored and ketor committed Apr 24, 2024
1 parent acfa4a3 commit 26b3e4b
Show file tree
Hide file tree
Showing 32 changed files with 660 additions and 234 deletions.
2 changes: 1 addition & 1 deletion src/benchmark/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ static bool IsMonotoneIncreasing(const std::vector<sdk::VectorWithId>& vector_wi
}

Operation::Result BaseOperation::VectorPut(VectorIndexEntryPtr entry,
const std::vector<sdk::VectorWithId>& vector_with_ids) {
std::vector<sdk::VectorWithId>& vector_with_ids) {
Operation::Result result;

IsMonotoneIncreasing(vector_with_ids);
Expand Down
2 changes: 1 addition & 1 deletion src/benchmark/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class BaseOperation : public Operation {
Result KvTxnGet(const std::vector<std::string>& keys);
Result KvTxnBatchGet(const std::vector<std::vector<std::string>>& keys);

Result VectorPut(VectorIndexEntryPtr entry, const std::vector<sdk::VectorWithId>& vector_with_ids);
Result VectorPut(VectorIndexEntryPtr entry, std::vector<sdk::VectorWithId>& vector_with_ids);
Result VectorSearch(VectorIndexEntryPtr entry, const std::vector<sdk::VectorWithId>& vector_with_ids,
const sdk::SearchParam& search_param);

Expand Down
127 changes: 79 additions & 48 deletions src/example/sdk_vector_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ static void PrepareVectorIndex() {
.SetReplicaNum(3)
.SetRangePartitions(g_range_partition_seperator_ids)
.SetFlatParam(g_flat_param)
.SetAutoIncrement(true)
.SetAutoIncrementStart(1)
.SetScalarSchema(schema)
.Create(g_index_id);
DINGO_LOG(INFO) << "Create index status: " << create.ToString() << ", index_id:" << g_index_id;
Expand All @@ -87,51 +85,6 @@ void PostClean(bool use_index_name = false) {
g_vector_ids.clear();
}

// TODO: remove
static void VectorIndexCacheSearch() {
auto coordinator_proxy = std::make_shared<dingodb::sdk::CoordinatorProxy>();
Status open = coordinator_proxy->Open(FLAGS_coordinator_url);
CHECK(open.ok()) << "Fail to open coordinator_proxy, please check parameter --url=" << FLAGS_coordinator_url;

dingodb::sdk::VectorIndexCache cache(*coordinator_proxy);
{
std::shared_ptr<dingodb::sdk::VectorIndex> index;
Status got = cache.GetVectorIndexById(g_index_id, index);
CHECK(got.ok()) << "Fail to get vector index, index_id:" << g_index_id << ", status:" << got.ToString();
CHECK(index.get() != nullptr);
CHECK_EQ(index->GetId(), g_index_id);
CHECK_EQ(index->GetName(), g_index_name);
}

{
std::shared_ptr<dingodb::sdk::VectorIndex> index;
Status got = cache.GetVectorIndexByKey(dingodb::sdk::EncodeVectorIndexCacheKey(g_schema_id, g_index_name), index);
CHECK(got.ok()) << "Fail to get vector index, index_name:" << g_index_name << ", status:" << got.ToString();
CHECK(index.get() != nullptr);
CHECK_EQ(index->GetId(), g_index_id);
CHECK_EQ(index->GetName(), g_index_name);
}

{
int64_t index_id{0};
Status got = cache.GetIndexIdByKey(dingodb::sdk::EncodeVectorIndexCacheKey(g_schema_id, g_index_name), index_id);
CHECK(got.ok()) << "Fail to get index_id, index_name" << g_index_name << ", status:" << got.ToString();
CHECK_EQ(index_id, g_index_id);
}

{
cache.RemoveVectorIndexById(g_index_id);
{
std::shared_ptr<dingodb::sdk::VectorIndex> index;
Status got = cache.GetVectorIndexByKey(dingodb::sdk::EncodeVectorIndexCacheKey(g_schema_id, g_index_name), index);
CHECK(got.ok()) << "Fail to get vector index, index_name:" << g_index_name << ", status:" << got.ToString();
CHECK(index.get() != nullptr);
CHECK_EQ(index->GetId(), g_index_id);
CHECK_EQ(index->GetName(), g_index_name);
}
}
}

static void PrepareVectorClient() {
dingodb::sdk::VectorClient* client;
Status built = g_client->NewVectorClient(&client);
Expand Down Expand Up @@ -627,6 +580,59 @@ static void VectorDelete(bool use_index_name = false) {
}
}

static void VectorAddWithAutoId(int64_t start_id) {
std::vector<int64_t> vector_ids;
Status add;
{
std::vector<dingodb::sdk::VectorWithId> vectors;

float delta = 0.1;
int64_t count = 5;
for (auto id = start_id; id < start_id + count; id++) {
dingodb::sdk::Vector tmp_vector{dingodb::sdk::ValueType::kFloat, g_dimension};
tmp_vector.float_values.push_back(1.0 + delta);
tmp_vector.float_values.push_back(2.0 + delta);

dingodb::sdk::VectorWithId tmp(0, std::move(tmp_vector));
vectors.push_back(std::move(tmp));

vector_ids.push_back(id);

delta++;
}

add = g_vector_client->AddByIndexId(g_index_id, vectors, false, false);

DINGO_LOG(INFO) << "vector add:" << add.ToString();
}

{
dingodb::sdk::ScanQueryParam param;
param.vector_id_start = 1;
param.vector_id_end = 100;
param.max_scan_count = 100;

dingodb::sdk::ScanQueryResult result;
Status tmp = g_vector_client->ScanQueryByIndexId(g_index_id, param, result);

DINGO_LOG(INFO) << "vector forward scan query: " << tmp.ToString() << ", result:" << result.ToString();
if (tmp.ok()) {
std::vector<int64_t> target_ids;
target_ids.reserve(result.vectors.size());
for (auto& vector : result.vectors) {
target_ids.push_back(vector.id);
}

if (add.ok()) {
// sort vecotor_ids and target_ids, and check equal
std::sort(vector_ids.begin(), vector_ids.end());
std::sort(target_ids.begin(), target_ids.end());
CHECK(std::equal(vector_ids.begin(), vector_ids.end(), target_ids.begin(), target_ids.end()));
}
}
}
}

int main(int argc, char* argv[]) {
FLAGS_minloglevel = google::GLOG_INFO;
FLAGS_logtostdout = true;
Expand All @@ -653,7 +659,6 @@ int main(int argc, char* argv[]) {

{
PrepareVectorIndex();
VectorIndexCacheSearch();
PrepareVectorClient();

VectorAdd();
Expand Down Expand Up @@ -687,4 +692,30 @@ int main(int argc, char* argv[]) {

PostClean(true);
}

{
int64_t start_id = 1;

{
dingodb::sdk::VectorIndexCreator* creator;
Status built = g_client->NewVectorIndexCreator(&creator);
CHECK(built.IsOK()) << "dingo creator build fail:" << built.ToString();
CHECK_NOTNULL(creator);
dingodb::ScopeGuard guard([&]() { delete creator; });

Status create = creator->SetSchemaId(g_schema_id)
.SetName(g_index_name)
.SetReplicaNum(3)
.SetRangePartitions(g_range_partition_seperator_ids)
.SetFlatParam(g_flat_param)
.SetAutoIncrementStart(start_id)
.Create(g_index_id);
DINGO_LOG(INFO) << "Create index status: " << create.ToString() << ", index_id:" << g_index_id;
sleep(20);
}

PrepareVectorClient();
VectorAddWithAutoId(start_id);
PostClean(true);
}
}
1 change: 1 addition & 0 deletions src/sdk/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

add_library(sdk
admin_tool.cc
auto_increment_manager.cc
client_stub.cc
client.cc
coordinator_proxy.cc
Expand Down
129 changes: 129 additions & 0 deletions src/sdk/auto_increment_manager.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sdk/auto_increment_manager.h"

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <utility>

#include "common/logging.h"
#include "glog/logging.h"
#include "proto/common.pb.h"
#include "proto/meta.pb.h"
#include "sdk/client_stub.h"
#include "sdk/status.h"

namespace dingodb {

namespace sdk {

struct AutoInrementer::Req {
explicit Req() = default;
std::condition_variable cv;
};

Status AutoInrementer::GetNextId(int64_t& next) {
std::vector<int64_t> ids;
DINGO_RETURN_NOT_OK(GetNextIds(ids, 1));
CHECK(!ids.empty());
next = ids.front();
return Status::OK();
}

Status AutoInrementer::GetNextIds(std::vector<int64_t>& to_fill, int64_t count) {
CHECK_GT(count, 0);
Req req;

{
std::unique_lock<std::mutex> lk(mutex_);
queue_.push_back(&req);
while (&req != queue_.front()) {
req.cv.wait(lk);
}
}

Status s;
while (s.ok() && count > 0) {
if (id_cache_.size() < count) {
s = RefillCache();
} else {
to_fill.insert(to_fill.end(), id_cache_.begin(), id_cache_.begin() + count);
id_cache_.erase(id_cache_.begin(), id_cache_.begin() + count);
count = 0;
}
}

{
std::unique_lock<std::mutex> lk(mutex_);
queue_.pop_front();
if (!queue_.empty()) {
queue_.front()->cv.notify_one();
}
}

return s;
}

Status AutoInrementer::RefillCache() {
pb::meta::GenerateAutoIncrementRequest request;
PrepareRequest(request);
pb::meta::GenerateAutoIncrementResponse response;
Status s = stub_.GetCoordinatorProxy()->GenerateAutoIncrement(request, response);
VLOG(kSdkVlogLevel) << "GenerateAutoIncrement request:" << request.DebugString()
<< " response:" << response.DebugString();

DINGO_RETURN_NOT_OK(s);
// TODO: maybe not crash just return error msg
CHECK_GT(response.end_id(), response.start_id())
<< " request:" << request.DebugString() << " response: " << response.DebugString();
for (int64_t i = response.start_id(); i < response.end_id(); i++) {
id_cache_.push_back(i);
}
return Status::OK();
}

void IndexAutoInrementer::PrepareRequest(pb::meta::GenerateAutoIncrementRequest& request) {
*request.mutable_table_id() = vector_index_->GetIndexDefWithId().index_id();
request.set_count(FLAGS_auto_incre_req_count);
request.set_auto_increment_increment(1);
request.set_auto_increment_offset(vector_index_->GetIncrementStartId());
}

std::shared_ptr<AutoInrementer> AutoIncrementerManager::GetOrCreateIndexIncrementer(
std::shared_ptr<VectorIndex>& index) {
std::unique_lock<std::mutex> lk(mutex_);
int64_t index_id = index->GetId();
auto iter = auto_incrementer_map_.find(index_id);
if (iter != auto_incrementer_map_.end()) {
return iter->second;
} else {
auto incrementer = std::make_shared<IndexAutoInrementer>(stub_, index);
CHECK(auto_incrementer_map_.emplace(std::make_pair(index_id, incrementer)).second);
return incrementer;
}
}

void AutoIncrementerManager::RemoveIndexIncrementerById(int64_t index_id) {
std::unique_lock<std::mutex> lk(mutex_);
auto iter = auto_incrementer_map_.find(index_id);
if (iter != auto_incrementer_map_.end()) {
auto_incrementer_map_.erase(iter);
}
}

} // namespace sdk
} // namespace dingodb
Loading

0 comments on commit 26b3e4b

Please sign in to comment.