Skip to content

Commit

Permalink
feat!: fp16/bf16 vector
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangyinzuo committed Dec 13, 2024
1 parent 0c1a669 commit c9993a9
Show file tree
Hide file tree
Showing 22 changed files with 874 additions and 191 deletions.
8 changes: 2 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ project(milvus_sdk LANGUAGES CXX)

set(CMAKE_VERBOSE_MAKEFILE OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
Expand Down Expand Up @@ -65,13 +65,9 @@ define_option_string(MILVUS_WITH_GRPC "Using gRPC from" "module"
define_option_string(MILVUS_WITH_ZLIB "Using Zlib from" "module" "package" "module")
define_option_string(MILVUS_WITH_NLOHMANN_JSON "nlohmann json from" "module" "package" "module")
define_option_string(MILVUS_WITH_GTEST "Using GTest from" "module" "package" "module")
define_option_string(MILVUS_WITH_EIGEN "Using Eigen from" "module" "package" "module")


set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED on)
set(BUILD_SCRIPTS_DIR ${PROJECT_SOURCE_DIR}/scripts)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)


# load third packages and milvus-proto
include(ThirdPartyPackages)
Expand Down
4 changes: 2 additions & 2 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ Or `make all-release` to build the release version.
You can also create a dedicated CMake build directory and build from source yourself with CMake:

```shell
$ mkdir build
$ cd build
$ mkdir cmake_build
$ cd cmake_build
$ cmake ..
$ make
```
Expand Down
20 changes: 20 additions & 0 deletions cmake/ThirdPartyPackages.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,31 @@ FetchContent_Declare(
grpc
GIT_REPOSITORY https://github.com/grpc/grpc.git
GIT_TAG v${GRPC_VERSION}
GIT_SHALLOW TRUE
)

# nlohmann_json: register the upstream Git repository with FetchContent.
# The tag is derived from NLOHMANN_JSON_VERSION; GIT_SHALLOW requests a
# shallow clone instead of the full history.
FetchContent_Declare(
    nlohmann_json
    GIT_REPOSITORY  https://github.com/nlohmann/json.git
    GIT_TAG         v${NLOHMANN_JSON_VERSION}
    GIT_SHALLOW     TRUE
)

# googletest: register the upstream Git repository with FetchContent.
# GIT_SHALLOW requests a shallow clone instead of the full history.
# NOTE(review): the tag is built as "release-<GOOGLETEST_VERSION>"; confirm
# that this prefix matches the upstream tag name for the configured version.
FetchContent_Declare(
    googletest
    GIT_REPOSITORY  https://github.com/google/googletest.git
    GIT_TAG         release-${GOOGLETEST_VERSION}
    GIT_SHALLOW     TRUE
)

# eigen3: register the upstream Git repository with FetchContent.
# The tag is pinned to the literal 3.4.0 (no version variable is used here);
# GIT_SHALLOW requests a shallow clone instead of the full history.
FetchContent_Declare(
    eigen3
    GIT_REPOSITORY  https://gitlab.com/libeigen/eigen.git
    GIT_TAG         3.4.0
    GIT_SHALLOW     TRUE
)

# grpc
# The option value to match is "package"; the misspelling "pakcage" could never
# equal a valid MILVUS_WITH_GRPC value, so this branch was unreachable.
if ("${MILVUS_WITH_GRPC}" STREQUAL "package")
Expand Down Expand Up @@ -79,3 +88,14 @@ else ()
add_subdirectory(${nlohmann_json_SOURCE_DIR} ${nlohmann_json_BINARY_DIR} EXCLUDE_FROM_ALL)
endif ()
endif ()

# eigen3
# "package" resolves Eigen3 through an installed CMake config package
# (NO_MODULE skips Find-modules). Any other value populates the eigen3 content
# declared with FetchContent_Declare and adds it as a subdirectory.
if ("${MILVUS_WITH_EIGEN}" STREQUAL "package")
    find_package(Eigen3 REQUIRED NO_MODULE)
else ()
    # Load eigen3_POPULATED / eigen3_SOURCE_DIR / eigen3_BINARY_DIR first so the
    # guard below can see a population that already happened earlier in this
    # configure run; calling FetchContent_Populate twice for the same name is
    # an error.
    FetchContent_GetProperties(eigen3)
    if (NOT eigen3_POPULATED)
        FetchContent_Populate(eigen3)
        # Presumably set to keep Eigen's own tests out of the build.
        # NOTE(review): this is a cache entry, so it is visible to every
        # project in the build that consults BUILD_TESTING - confirm intended.
        set(BUILD_TESTING OFF CACHE INTERNAL "")
        add_subdirectory(${eigen3_SOURCE_DIR} ${eigen3_BINARY_DIR} EXCLUDE_FROM_ALL)
    endif ()
endif ()
2 changes: 1 addition & 1 deletion examples/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ main(int argc, char* argv[]) {
std::uniform_int_distribution<int64_t> int64_gen(0, row_count - 1);
int64_t q_number = int64_gen(ran);
std::vector<float> q_vector = insert_vectors[q_number];
arguments.AddTargetVector(field_face_name, std::move(q_vector));
arguments.AddTargetVector<milvus::FloatVecFieldData>(field_face_name, std::move(q_vector));
std::cout << "Searching the No." << q_number << " entity..." << std::endl;

milvus::SearchResults search_results{};
Expand Down
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ add_library(milvus_sdk ${impl_files} ${impl_types_files})
# add proto gens
add_milvus_protos(milvus_sdk)
set_target_properties(milvus_sdk PROPERTIES OUTPUT_NAME milvus_sdk)
target_link_libraries(milvus_sdk gRPC::grpc++ nlohmann_json::nlohmann_json)
target_link_libraries(milvus_sdk gRPC::grpc++ nlohmann_json::nlohmann_json Eigen3::Eigen)
target_include_directories(milvus_sdk PUBLIC include)


Expand All @@ -37,4 +37,4 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/
install(TARGETS milvus_sdk
EXPORT milvus_sdk
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
80 changes: 61 additions & 19 deletions src/impl/MilvusClientImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,51 @@ MilvusClientImpl::Delete(const std::string& collection_name, const std::string&
post);
}

/**
 * Compile-time mapping from a vector DataType to the pieces needed to build a
 * search placeholder:
 *   - FieldDataType   : field-data class the target vectors are cast to
 *   - type            : protobuf PlaceholderType tag written to the request
 *   - size_of_element : size in bytes of one element of a vector
 *
 * The primary template is declared but not defined, so using a DataType
 * without a specialization below is a compile error.
 */
template <DataType Dt>
struct SearchTrait;

/** Binary vectors: elements are sized as uint8_t. */
template <>
struct SearchTrait<DataType::BINARY_VECTOR> {
    using FieldDataType = BinaryVecFieldData;
    static constexpr size_t size_of_element = sizeof(uint8_t);
    static constexpr auto type = proto::common::PlaceholderType::BinaryVector;
};

/** Float vectors: elements are sized as float. */
template <>
struct SearchTrait<DataType::FLOAT_VECTOR> {
    using FieldDataType = FloatVecFieldData;
    static constexpr size_t size_of_element = sizeof(float);
    static constexpr auto type = proto::common::PlaceholderType::FloatVector;
};

/** Float16 vectors: elements are sized as Eigen::half, asserted to be 2 bytes. */
template <>
struct SearchTrait<DataType::FLOAT16_VECTOR> {
    static_assert(sizeof(Eigen::half) == 2, "Eigen::half size is not 2");
    using FieldDataType = Float16VecFieldData;
    static constexpr size_t size_of_element = sizeof(Eigen::half);
    static constexpr auto type = proto::common::PlaceholderType::Float16Vector;
};

/** BFloat16 vectors: elements are sized as Eigen::bfloat16, asserted to be 2 bytes. */
template <>
struct SearchTrait<DataType::BFLOAT16_VECTOR> {
    static_assert(sizeof(Eigen::bfloat16) == 2, "Eigen::bfloat16 size is not 2");
    using FieldDataType = BFloat16VecFieldData;
    static constexpr size_t size_of_element = sizeof(Eigen::bfloat16);
    static constexpr auto type = proto::common::PlaceholderType::BFloat16Vector;
};

/**
 * Fill a search placeholder with the raw bytes of every target vector.
 *
 * @tparam Dt vector data type; selects the SearchTrait specialization that
 *            supplies the placeholder type tag, the concrete field-data class
 *            and the per-element byte size.
 * @param placeholder_value placeholder to populate; its type is set and one
 *                          value is appended per target vector.
 * @param target field data holding the target vectors; it must actually be a
 *               SearchTrait<Dt>::FieldDataType, otherwise the reference
 *               dynamic_cast throws std::bad_cast.
 */
template <DataType Dt>
static void
SetPlaceHolderValue(proto::common::PlaceholderValue& placeholder_value, FieldDataPtr target) {
    using Trait = SearchTrait<Dt>;

    placeholder_value.set_type(Trait::type);
    // The trait member is FieldDataType; SearchTrait has no member named
    // BinaryVecFieldData, so the previous spelling could not be instantiated.
    auto& vec = dynamic_cast<typename Trait::FieldDataType&>(*target);
    for (const auto& v : vec.Data()) {
        // Byte length = element count * bytes per element for this vector type.
        std::string placeholder_data(reinterpret_cast<const char*>(v.data()), v.size() * Trait::size_of_element);
        placeholder_value.add_values(std::move(placeholder_data));
    }
}

Status
MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& results, int timeout) {
std::string anns_field;
Expand Down Expand Up @@ -706,23 +751,21 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result
auto& placeholder_value = *placeholder_group.add_placeholders();
placeholder_value.set_tag("$0");
auto target = arguments.TargetVectors();
if (target->Type() == DataType::BINARY_VECTOR) {
// bins
placeholder_value.set_type(proto::common::PlaceholderType::BinaryVector);
auto& bins_vec = dynamic_cast<BinaryVecFieldData&>(*target);
for (const auto& bins : bins_vec.Data()) {
std::string placeholder_data(reinterpret_cast<const char*>(bins.data()), bins.size());
placeholder_value.add_values(std::move(placeholder_data));
}
} else {
// floats
placeholder_value.set_type(proto::common::PlaceholderType::FloatVector);
auto& floats_vec = dynamic_cast<FloatVecFieldData&>(*target);
for (const auto& floats : floats_vec.Data()) {
std::string placeholder_data(reinterpret_cast<const char*>(floats.data()),
floats.size() * sizeof(float));
placeholder_value.add_values(std::move(placeholder_data));
}
switch (target->Type()) {
case DataType::BINARY_VECTOR:
SetPlaceHolderValue<DataType::BINARY_VECTOR>(placeholder_value, target);
break;
case DataType::FLOAT_VECTOR:
SetPlaceHolderValue<DataType::FLOAT_VECTOR>(placeholder_value, target);
break;
case DataType::FLOAT16_VECTOR:
SetPlaceHolderValue<DataType::FLOAT16_VECTOR>(placeholder_value, target);
break;
case DataType::BFLOAT16_VECTOR:
SetPlaceHolderValue<DataType::BFLOAT16_VECTOR>(placeholder_value, target);
break;
default:
assert(false);
}
rpc_request.set_placeholder_group(std::move(placeholder_group.SerializeAsString()));

Expand Down Expand Up @@ -839,8 +882,7 @@ MilvusClientImpl::CalcDistance(const CalcDistanceArguments& arguments, DistanceA

if (arg_vectors->Type() == DataType::FLOAT_VECTOR) {
FloatVecFieldDataPtr data_ptr = std::static_pointer_cast<FloatVecFieldData>(arg_vectors);
auto float_vectors = data_array->mutable_float_vector();
auto mutable_data = float_vectors->mutable_data();
auto mutable_data = data_array->mutable_float_vector()->mutable_data();
auto& vectors = data_ptr->Data();
for (auto& vector : vectors) {
mutable_data->Add(vector.begin(), vector.end());
Expand Down
Loading

0 comments on commit c9993a9

Please sign in to comment.