diff --git a/src/impl/TypeUtils.cpp b/src/impl/TypeUtils.cpp index ae213d0..9c85ead 100644 --- a/src/impl/TypeUtils.cpp +++ b/src/impl/TypeUtils.cpp @@ -393,6 +393,12 @@ IndexTypeCast(const std::string& type) { if (type == "BIN_IVF_FLAT") { return IndexType::BIN_IVF_FLAT; } + if (type == "GPU_IVF_FLAT") { + return IndexType::GPU_IVF_FLAT; + } + if (type == "GPU_IVF_PQ") { + return IndexType::GPU_IVF_PQ; + } return IndexType::INVALID; } @@ -897,6 +903,10 @@ to_string(milvus::IndexType index_type) { return "BIN_FLAT"; case milvus::IndexType::BIN_IVF_FLAT: return "BIN_IVF_FLAT"; + case milvus::IndexType::GPU_IVF_FLAT: + return "GPU_IVF_FLAT"; + case milvus::IndexType::GPU_IVF_PQ: + return "GPU_IVF_PQ"; default: return "INVALID"; } diff --git a/src/impl/types/IndexDesc.cpp b/src/impl/types/IndexDesc.cpp index 7693966..d95fce4 100644 --- a/src/impl/types/IndexDesc.cpp +++ b/src/impl/types/IndexDesc.cpp @@ -57,6 +57,7 @@ Status validate_index_and_metric(const MetricType metric_type, const IndexType index_type) { if ((metric_type == milvus::MetricType::IP || metric_type == milvus::MetricType::L2) && (index_type == milvus::IndexType::FLAT || index_type == milvus::IndexType::IVF_FLAT || + index_type == milvus::IndexType::GPU_IVF_FLAT || index_type == milvus::IndexType::GPU_IVF_PQ || index_type == milvus::IndexType::IVF_SQ8 || index_type == milvus::IndexType::IVF_PQ || index_type == milvus::IndexType::HNSW || index_type == milvus::IndexType::IVF_HNSW || index_type == milvus::IndexType::RHNSW_FLAT || index_type == milvus::IndexType::RHNSW_SQ || @@ -90,6 +91,12 @@ validate_params(const IndexDesc& data, const std::unordered_mapCreateIndex(collection_name, index_desc, progress_monitor); EXPECT_FALSE(status.IsOk()); +} + +TEST_F(MilvusMockedTest, TestCreateGPUIndexInstantly) { + milvus::ConnectParam connect_param{"127.0.0.1", server_.ListenPort()}; + client_->Connect(connect_param); + + std::string collection_name = "test_collection"; + std::string field_name = "test_field"; + std::string index_name = "test_gpu_index"; + auto index_type = milvus::IndexType::GPU_IVF_FLAT; + auto metric_type = milvus::MetricType::IP; + int64_t index_id = 0; + + milvus::IndexDesc index_desc(field_name, "", index_type, metric_type, index_id); + index_desc.AddExtraParam("nlist", 1024); + const auto progress_monitor = ::milvus::ProgressMonitor::NoWait(); + + EXPECT_CALL(service_, Flush(_, AllOf(Property(&FlushRequest::collection_names, ElementsAre(collection_name))), _)) + .WillOnce([&](::grpc::ServerContext*, const FlushRequest*, FlushResponse*) { return ::grpc::Status{}; }); + + EXPECT_CALL(service_, CreateIndex(_, + AllOf(Property(&CreateIndexRequest::collection_name, collection_name), + Property(&CreateIndexRequest::field_name, field_name)), + _)) + .WillOnce([](::grpc::ServerContext*, const CreateIndexRequest*, ::milvus::proto::common::Status*) { + return ::grpc::Status{}; + }); + auto status = client_->CreateIndex(collection_name, index_desc, progress_monitor); + EXPECT_TRUE(status.IsOk()); } \ No newline at end of file