Skip to content

Commit

Permalink
[feat][sdk] Support Vector add support update
Browse files Browse the repository at this point in the history
  • Loading branch information
wchuande authored and ketor committed Apr 25, 2024
1 parent 708e8cd commit f0aac63
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 9 deletions.
39 changes: 30 additions & 9 deletions src/sdk/vector/vector_add_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,36 @@ Status VectorAddTask::Init() {
vector_index_ = std::move(tmp);

if (vector_index_->HasAutoIncrement()) {
auto incrementer = stub.GetAutoIncrementerManager()->GetOrCreateIndexIncrementer(vector_index_);
std::vector<int64_t> ids;
int64_t id_count = vectors_.size();
ids.reserve(id_count);
DINGO_RETURN_NOT_OK(incrementer->GetNextIds(ids, id_count));
CHECK_EQ(ids.size(), id_count);

for (auto i = 0; i < id_count; i++) {
vectors_[i].id = ids[i];
bool has_id = vectors_[0].id > 0;
for (int i = 1; i < vectors_.size(); i++) {
bool next_has_id = vectors_[i].id > 0;
if (has_id ^ next_has_id) {
return Status::InvalidArgument("vector id must be all positive or not when vector index has auto increment");
} else {
has_id = next_has_id;
}
}

if (!has_id) {
auto incrementer = stub.GetAutoIncrementerManager()->GetOrCreateIndexIncrementer(vector_index_);
std::vector<int64_t> ids;
int64_t id_count = vectors_.size();
ids.reserve(id_count);

DINGO_RETURN_NOT_OK(incrementer->GetNextIds(ids, id_count));
CHECK_EQ(ids.size(), id_count);

for (auto i = 0; i < id_count; i++) {
vectors_[i].id = ids[i];
}
}

} else {
for (auto& vector : vectors_) {
int64_t id = vector.id;
if (id <= 0) {
return Status::InvalidArgument("vector id must be positive");
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/sdk/vector/vector_add_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <unordered_map>

#include "glog/logging.h"
#include "sdk/client_stub.h"
#include "sdk/store/store_rpc_controller.h"
#include "sdk/vector/index_service_rpc.h"
Expand All @@ -40,6 +41,8 @@ class VectorAddTask : public VectorTask {

~VectorAddTask() override = default;

Status TEST_Init() { return Init(); }

private:
Status Init() override;
void DoAsync() override;
Expand Down
1 change: 1 addition & 0 deletions test/unit_test/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ int main(int argc, char* argv[]) {
default_run_case += ":RegionTest.*";
default_run_case += ":StoreRpcControllerTest.*";
default_run_case += ":ThreadPoolActuatorTest.*";
default_run_case += ":SDKVectorAddTaskTest.*";
default_run_case += ":SDKVectorCommonTest.*";
default_run_case += ":SDKVectorIndexCacheKeyTest.*";
default_run_case += ":SDKVectorIndexCacheTest.*";
Expand Down
206 changes: 206 additions & 0 deletions test/unit_test/sdk/vector/test_vector_add_task.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// Copyright (c) 2023 dingodb.com, Inc. All Rights Reserved
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <memory>

#include "gtest/gtest.h"
#include "sdk/vector.h"
#include "sdk/vector/vector_add_task.h"
#include "sdk/vector/vector_common.h"
#include "test_base.h"
namespace dingodb {
namespace sdk {

// Fixture shared by every SDKVectorAddTaskTest case. It adds no state of its
// own; the `stub` and `coordinator_proxy` objects used by the tests are not
// declared in this file and presumably come from TestBase.
// NOTE(review): both hooks below are overridden with empty bodies, so any
// work TestBase does in its own SetUp/TearDown is not executed for these
// tests -- confirm that is intended.
class SDKVectorAddTaskTest : public TestBase {
 public:
  // Nothing to release after a test case.
  void TearDown() override {}

  // Nothing to prepare before a test case.
  void SetUp() override {}
};

// Running an add task that was given no vectors at all is expected to be
// rejected with an InvalidArgument status.
TEST_F(SDKVectorAddTaskTest, EmptyVectors) {
  std::vector<VectorWithId> no_vectors;
  VectorAddTask add_task(*stub, 1, no_vectors);

  EXPECT_TRUE(add_task.Run().IsInvalidArgument());
}

static std::shared_ptr<VectorIndex> CreateFakeVectorIndex(int64_t start_id = 0) {
std::shared_ptr<VectorIndex> vector_index;

{
std::string index_name{"test"};
int64_t schema_id{2};
std::vector<int64_t> index_and_part_ids{2, 3, 4, 5, 6};
int64_t index_id = index_and_part_ids[0];
std::vector<int64_t> range_seperator_ids = {5, 10, 20};
FlatParam flat_param{1000, dingodb::sdk::MetricType::kL2};

pb::meta::IndexDefinitionWithId index_definition_with_id;
FillVectorIndexId(index_definition_with_id.mutable_index_id(), index_id, schema_id);
auto* defination = index_definition_with_id.mutable_index_definition();
defination->set_name(index_name);
FillRangePartitionRule(defination->mutable_index_partition(), range_seperator_ids, index_and_part_ids);
defination->set_replica(3);
if (start_id > 0) {
defination->set_with_auto_incrment(true);
defination->set_auto_increment(start_id);
}

auto* index_parameter = defination->mutable_index_parameter();
index_parameter->set_index_type(pb::common::IndexType::INDEX_TYPE_VECTOR);
FillFlatParmeter(index_parameter->mutable_vector_index_parameter(), flat_param);

vector_index = std::make_shared<VectorIndex>(index_definition_with_id);
}

return vector_index;
}

// The fake index is created without auto increment, and every vector carries
// id 0: Init is expected to report InvalidArgument.
TEST_F(SDKVectorAddTaskTest, InitNoAutoIncreFail) {
  auto fake_index = CreateFakeVectorIndex();

  // The coordinator is asked exactly once for the index and answers with the
  // fake definition.
  EXPECT_CALL(*coordinator_proxy, GetIndexById)
      .WillOnce([&fake_index](const pb::meta::GetIndexRequest& req, pb::meta::GetIndexResponse& resp) {
        EXPECT_EQ(req.index_id().entity_id(), fake_index->GetId());

        auto* returned_definition = resp.mutable_index_definition_with_id();
        *returned_definition = fake_index->GetIndexDefWithId();

        return Status::OK();
      });

  std::vector<VectorWithId> vectors;
  for (int n = 0; n < 10; ++n) {
    VectorWithId entry;
    entry.id = 0;
    vectors.push_back(entry);
  }

  VectorAddTask add_task(*stub, fake_index->GetId(), vectors);
  EXPECT_TRUE(add_task.TEST_Init().IsInvalidArgument());
}

// The fake index is created without auto increment, and the vectors carry the
// positive ids 1..10: Init is expected to succeed.
TEST_F(SDKVectorAddTaskTest, InitNoAutoIncreSuccess) {
  auto fake_index = CreateFakeVectorIndex();

  // The coordinator is asked exactly once for the index and answers with the
  // fake definition.
  EXPECT_CALL(*coordinator_proxy, GetIndexById)
      .WillOnce([&fake_index](const pb::meta::GetIndexRequest& req, pb::meta::GetIndexResponse& resp) {
        EXPECT_EQ(req.index_id().entity_id(), fake_index->GetId());

        auto* returned_definition = resp.mutable_index_definition_with_id();
        *returned_definition = fake_index->GetIndexDefWithId();

        return Status::OK();
      });

  std::vector<VectorWithId> vectors;
  for (int next_id = 1; next_id <= 10; ++next_id) {
    VectorWithId entry;
    entry.id = next_id;
    vectors.push_back(entry);
  }

  VectorAddTask add_task(*stub, fake_index->GetId(), vectors);
  EXPECT_TRUE(add_task.TEST_Init().ok());
}

// The fake index is created with auto increment (start id 1), and the batch
// mixes vectors with a positive id and vectors with id 0: Init is expected to
// report InvalidArgument.
TEST_F(SDKVectorAddTaskTest, InitAutoIncreFail) {
  auto fake_index = CreateFakeVectorIndex(1);

  // The coordinator is asked exactly once for the index and answers with the
  // fake definition.
  EXPECT_CALL(*coordinator_proxy, GetIndexById)
      .WillOnce([&fake_index](const pb::meta::GetIndexRequest& req, pb::meta::GetIndexResponse& resp) {
        EXPECT_EQ(req.index_id().entity_id(), fake_index->GetId());

        auto* returned_definition = resp.mutable_index_definition_with_id();
        *returned_definition = fake_index->GetIndexDefWithId();

        return Status::OK();
      });

  std::vector<VectorWithId> vectors;
  for (int n = 0; n < 10; ++n) {
    // Even positions get the positive id n + 1, odd positions get 0.
    const bool odd_position = (n % 2) != 0;

    VectorWithId entry;
    if (odd_position) {
      entry.id = 0;
    } else {
      entry.id = n + 1;
    }
    vectors.push_back(entry);
  }

  VectorAddTask add_task(*stub, fake_index->GetId(), vectors);
  EXPECT_TRUE(add_task.TEST_Init().IsInvalidArgument());
}

// The fake index is created with auto increment (start id 1), and every
// vector already carries a positive id (1..10): Init is expected to succeed.
// No GenerateAutoIncrement expectation is set up in this test.
TEST_F(SDKVectorAddTaskTest, InitAutoIncreSuccess) {
  auto fake_index = CreateFakeVectorIndex(1);

  // The coordinator is asked exactly once for the index and answers with the
  // fake definition.
  EXPECT_CALL(*coordinator_proxy, GetIndexById)
      .WillOnce([&fake_index](const pb::meta::GetIndexRequest& req, pb::meta::GetIndexResponse& resp) {
        EXPECT_EQ(req.index_id().entity_id(), fake_index->GetId());

        auto* returned_definition = resp.mutable_index_definition_with_id();
        *returned_definition = fake_index->GetIndexDefWithId();

        return Status::OK();
      });

  std::vector<VectorWithId> vectors;
  for (int next_id = 1; next_id <= 10; ++next_id) {
    VectorWithId entry;
    entry.id = next_id;
    vectors.push_back(entry);
  }

  VectorAddTask add_task(*stub, fake_index->GetId(), vectors);
  EXPECT_TRUE(add_task.TEST_Init().ok());
}

// The fake index is created with auto increment (start id 1), and every
// vector carries id 0. The mocked coordinator hands out the id range
// start=1 / end=100; after a successful Init the caller's vectors are
// expected to hold the ids 1..10 in order.
TEST_F(SDKVectorAddTaskTest, InitAutoIncreUseGenerateid) {
  auto fake_index = CreateFakeVectorIndex(1);

  // The coordinator is asked exactly once for the index and answers with the
  // fake definition.
  EXPECT_CALL(*coordinator_proxy, GetIndexById)
      .WillOnce([&fake_index](const pb::meta::GetIndexRequest& req, pb::meta::GetIndexResponse& resp) {
        EXPECT_EQ(req.index_id().entity_id(), fake_index->GetId());

        auto* returned_definition = resp.mutable_index_definition_with_id();
        *returned_definition = fake_index->GetIndexDefWithId();

        return Status::OK();
      });

  // Exactly one id-generation request is expected; its content is not checked.
  EXPECT_CALL(*coordinator_proxy, GenerateAutoIncrement)
      .WillOnce([](const pb::meta::GenerateAutoIncrementRequest& /*req*/,
                   pb::meta::GenerateAutoIncrementResponse& resp) {
        resp.set_start_id(1);
        resp.set_end_id(100);
        return Status::OK();
      });

  int64_t vector_count = 10;
  std::vector<VectorWithId> vectors;
  for (int64_t n = 0; n < vector_count; ++n) {
    VectorWithId entry;
    entry.id = 0;
    vectors.push_back(entry);
  }

  VectorAddTask add_task(*stub, fake_index->GetId(), vectors);
  Status init_status = add_task.TEST_Init();
  EXPECT_TRUE(init_status.ok());

  // The ids are read back from the caller's own container.
  for (int64_t pos = 0; pos < vector_count; ++pos) {
    EXPECT_EQ(vectors[pos].id, pos + 1);
  }
}

} // namespace sdk
} // namespace dingodb

0 comments on commit f0aac63

Please sign in to comment.