Skip to content

Commit

Permalink
[fix][sdk] Vector search without vector id
Browse files Browse the repository at this point in the history
  • Loading branch information
wchuande authored and rock-git committed Feb 2, 2024
1 parent ef496b1 commit 04755a2
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 44 deletions.
22 changes: 16 additions & 6 deletions src/example/sdk_vector_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,18 @@ static void VectorAdd(bool use_index_name = false) {

static void VectorSearch(bool use_index_name = false) {
std::vector<dingodb::sdk::VectorWithId> target_vectors;
{
float init = 0.1f;
for (int i = 0; i < 5; i++) {
dingodb::sdk::Vector tmp_vector{dingodb::sdk::ValueType::kFloat, g_dimension};
tmp_vector.float_values.push_back(1.5);
tmp_vector.float_values.push_back(1.5);
tmp_vector.float_values.clear();
tmp_vector.float_values.push_back(init);
tmp_vector.float_values.push_back(init);

dingodb::sdk::VectorWithId tmp;
tmp.vector = std::move(tmp_vector);
target_vectors.push_back(std::move(tmp));

init = init + 0.1;
}

dingodb::sdk::SearchParameter param;
Expand All @@ -181,9 +185,15 @@ static void VectorSearch(bool use_index_name = false) {
DINGO_LOG(INFO) << "vector search result:" << dingodb::sdk::DumpToString(r);
}

if (!result.empty()) {
CHECK_EQ(result.size(), 1);
CHECK_EQ(result[0].vector_datas.size(), 2);
CHECK_EQ(result.size(), target_vectors.size());
for (auto i = 0; i < result.size(); i++) {
auto& search_result = result[i];
if (!search_result.vector_datas.empty()) {
CHECK_EQ(search_result.vector_datas.size(), param.topk);
}
const auto& vector_id = search_result.id;
CHECK_EQ(vector_id.id, target_vectors[i].id);
CHECK_EQ(vector_id.vector.Size(), target_vectors[i].vector.Size());
}
}

Expand Down
1 change: 1 addition & 0 deletions src/sdk/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ struct VectorWithDistance {
std::string DumpToString(const VectorWithDistance& obj);

struct SearchResult {
// TODO : maybe remove VectorWithId
VectorWithId id;
std::vector<VectorWithDistance> vector_datas;

Expand Down
6 changes: 4 additions & 2 deletions src/sdk/vector/vector_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,10 @@ static pb::common::ValueType ValueType2InternalValueTypePB(ValueType value_type)
}
}

static void FillVectorWithIdPB(pb::common::VectorWithId* pb, const VectorWithId& vector_with_id) {
pb->set_id(vector_with_id.id);
static void FillVectorWithIdPB(pb::common::VectorWithId* pb, const VectorWithId& vector_with_id, bool with_id = true) {
if (with_id) {
pb->set_id(vector_with_id.id);
}
auto* vector_pb = pb->mutable_vector();
const auto& vector = vector_with_id.vector;
vector_pb->set_dimension(vector.dimension);
Expand Down
20 changes: 17 additions & 3 deletions src/sdk/vector/vector_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,24 @@ namespace dingodb {
namespace sdk {

std::string DumpToString(const Vector& obj) {
std::string float_values = fmt::format("{}", fmt::join(obj.float_values, ", "));
std::string binary_values = fmt::format("{}", fmt::join(obj.binary_values, ", "));
std::stringstream float_ss;
for (size_t i = 0; i < obj.float_values.size(); ++i) {
float_ss << obj.float_values[i];
if (i != obj.float_values.size() - 1) {
float_ss << ", ";
}
}

std::stringstream binary_ss;
for (size_t i = 0; i < obj.binary_values.size(); ++i) {
binary_ss << obj.binary_values[i];
if (i != obj.binary_values.size() - 1) {
binary_ss << ", ";
}
}

return fmt::format("Vector {{ dimension: {}, value_type: {}, float_values: [{}], binary_values: [{}] }}",
obj.dimension, ValueTypeToString(obj.value_type), float_values, binary_values);
obj.dimension, ValueTypeToString(obj.value_type), float_ss.str(), binary_ss.str());
}

std::string DumpToString(const VectorWithId& obj) {
Expand Down
64 changes: 33 additions & 31 deletions src/sdk/vector/vector_search_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>

#include "common/logging.h"
#include "common/synchronization.h"
#include "glog/logging.h"
#include "proto/common.pb.h"
#include "proto/index.pb.h"
Expand All @@ -45,10 +46,6 @@ Status VectorSearchTask::Init() {

auto part_ids = vector_index_->GetPartitionIds();

for (int64_t i = 0; i < target_vectors_.size(); i++) {
CHECK(vector_id_idx_.insert({target_vectors_[i].id, i}).second) << "duplicate vector id: " << target_vectors_[i].id;
}

for (const auto& part_id : part_ids) {
next_part_ids_.emplace(part_id);
}
Expand All @@ -73,12 +70,14 @@ void VectorSearchTask::DoAsync() {
sub_tasks_count_.store(next_part_ids.size());

for (const auto& part_id : next_part_ids) {
auto sub_task = std::make_shared<VectorSearchPartTask>(stub, index_id_, part_id, search_param_, target_vectors_);
auto* sub_task = new VectorSearchPartTask(stub, index_id_, part_id, search_param_, target_vectors_);
sub_task->AsyncRun([this, sub_task](auto&& s) { SubTaskCallback(std::forward<decltype(s)>(s), sub_task); });
}
}

void VectorSearchTask::SubTaskCallback(Status status, std::shared_ptr<VectorSearchPartTask> sub_task) {
void VectorSearchTask::SubTaskCallback(Status status, VectorSearchPartTask* sub_task) {
DEFER(delete sub_task);

if (!status.ok()) {
DINGO_LOG(WARNING) << "sub_task: " << sub_task->Name() << " fail: " << status.ToString();

Expand Down Expand Up @@ -116,20 +115,8 @@ void VectorSearchTask::SubTaskCallback(Status status, std::shared_ptr<VectorSear
}

void VectorSearchTask::ConstructResultUnlocked() {
for (auto& iter : tmp_out_result_) {
auto& vec = iter.second;
std::sort(vec.begin(), vec.end(),
[](const VectorWithDistance& a, const VectorWithDistance& b) { return a.distance < b.distance; });
}

for (auto& iter : tmp_out_result_) {
int64_t vector_id = iter.first;
auto& vec_distance = iter.second;
const auto& vector_with_id = target_vectors_[vector_id];
CHECK_EQ(vector_with_id.id, vector_id);

for (const auto& vector_with_id : target_vectors_) {
VectorWithId tmp;
tmp.id = vector_with_id.id;
{
// NOTE: use copy
const Vector& to_copy = vector_with_id.vector;
Expand All @@ -140,14 +127,24 @@ void VectorSearchTask::ConstructResultUnlocked() {
}

SearchResult search(std::move(tmp));
search.vector_datas = std::move(vec_distance);

if (!search_param_.enable_range_search && search_param_.topk > 0 &&
search_param_.topk < search.vector_datas.size()) {
search.vector_datas.resize(search_param_.topk);
out_result_.push_back(std::move(search));
}

for (auto& iter : tmp_out_result_) {
auto& vec = iter.second;
std::sort(vec.begin(), vec.end(),
[](const VectorWithDistance& a, const VectorWithDistance& b) { return a.distance < b.distance; });
}

for (auto& iter : tmp_out_result_) {
int64_t idx = iter.first;
auto& vec_distance = iter.second;
if (!search_param_.enable_range_search && search_param_.topk > 0 && search_param_.topk < vec_distance.size()) {
vec_distance.resize(search_param_.topk);
}

out_result_.push_back(std::move(search));
out_result_[idx].vector_datas = std::move(vec_distance);
}
}

Expand Down Expand Up @@ -200,11 +197,15 @@ void VectorSearchPartTask::FillVectorSearchRpcRequest(pb::index::VectorSearchReq
FillRpcContext(*request->mutable_context(), region->RegionId(), region->Epoch());
FillInternalSearchParams(request->mutable_parameter(), vector_index_->GetVectorIndexType(), search_param_);
for (const auto& vector_id : target_vectors_) {
FillVectorWithIdPB(request->add_vector_with_ids(), vector_id);
// NOTE* vector_id is useless
FillVectorWithIdPB(request->add_vector_with_ids(), vector_id, false);
}
}

void VectorSearchPartTask::VectorSearchRpcCallback(const Status& status, VectorSearchRpc* rpc) {
// TODO : to remove
VLOG(kSdkVlogLevel) << "rpc: " << rpc->Method() << " request: " << rpc->Request()->DebugString()
<< " response: " << rpc->Response()->DebugString();
if (!status.ok()) {
DINGO_LOG(WARNING) << "rpc: " << rpc->Method() << " send to region: " << rpc->Request()->context().region_id()
<< " fail: " << status.ToString();
Expand All @@ -215,15 +216,16 @@ void VectorSearchPartTask::VectorSearchRpcCallback(const Status& status, VectorS
status_ = status;
}
} else {
VLOG(kSdkVlogLevel) << Name() << ", rpc: " << rpc->Method()
<< " send to region: " << rpc->Request()->context().region_id()
<< " status: " << status.ToString() << " request: " << rpc->Request()->DebugString()
<< " response: " << rpc->Response()->DebugString();
CHECK_EQ(rpc->Response()->batch_results_size(), rpc->Request()->vector_with_ids_size())
<< Name() << ", rpc: " << rpc->Method()
<< " request vector_with_ids_size: " << rpc->Request()->vector_with_ids_size()
<< " response batch_results_size: " << rpc->Response()->batch_results_size()
<< " request: " << rpc->Request()->DebugString() << " response: " << rpc->Response()->DebugString();

for (auto i = 0; i < rpc->Response()->batch_results_size(); i++) {
int64_t vector_id = rpc->Request()->vector_with_ids(i).id();
for (const auto& distancepb : rpc->Response()->batch_results(i).vector_with_distances()) {
VectorWithDistance distance = InternalVectorWithDistance2VectorWithDistance(distancepb);
search_result_[vector_id].push_back(std::move(distance));
search_result_[i].push_back(std::move(distance));
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/sdk/vector/vector_search_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class VectorSearchTask : public VectorTask {

std::string Name() const override { return fmt::format("VectorSearchTask-{}", index_id_); }

void SubTaskCallback(Status status, std::shared_ptr<VectorSearchPartTask> sub_task);
void SubTaskCallback(Status status, VectorSearchPartTask* sub_task);

void ConstructResultUnlocked();

const int64_t index_id_;
const SearchParameter& search_param_;
const std::vector<VectorWithId>& target_vectors_;
std::unordered_map<int64_t, int64_t> vector_id_idx_;

// target_vectors_ idx to search result
std::unordered_map<int64_t, std::vector<VectorWithDistance>> tmp_out_result_;

std::vector<SearchResult>& out_result_;
Expand Down Expand Up @@ -105,6 +105,7 @@ class VectorSearchPartTask : public VectorTask {

std::unordered_map<int64_t, std::shared_ptr<Region>> next_batch_region_;

// target_vectors_ idx to search result
std::unordered_map<int64_t, std::vector<VectorWithDistance>> search_result_;

std::vector<StoreRpcController> controllers_;
Expand Down

0 comments on commit 04755a2

Please sign in to comment.