Skip to content

Commit

Permalink
Merge pull request ClickHouse#54103 from ClickHouse/ustweaks
Browse files Browse the repository at this point in the history
Small usearch index improvements: metrics and configurable internal data type
  • Loading branch information
rschu1ze committed Sep 14, 2023
2 parents f6d5ed2 + e018f1d commit f0eadd4
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 36 deletions.
9 changes: 6 additions & 3 deletions docs/en/engines/table-engines/mergetree-family/annindexes.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ CREATE TABLE table_with_usearch_index
(
id Int64,
vectors Array(Float32),
INDEX [ann_index_name] vectors TYPE usearch([Distance]) [GRANULARITY N]
INDEX [ann_index_name] vectors TYPE usearch([Distance[, ScalarKind]]) [GRANULARITY N]
)
ENGINE = MergeTree
ORDER BY id;
Expand All @@ -265,7 +265,7 @@ CREATE TABLE table_with_usearch_index
(
id Int64,
vectors Tuple(Float32[, Float32[, ...]]),
INDEX [ann_index_name] vectors TYPE usearch([Distance]) [GRANULARITY N]
INDEX [ann_index_name] vectors TYPE usearch([Distance[, ScalarKind]]) [GRANULARITY N]
)
ENGINE = MergeTree
ORDER BY id;
Expand All @@ -277,5 +277,8 @@ USearch currently supports two distance functions:
- `cosineDistance`, also called cosine similarity, is the cosine of the angle between two (non-zero) vectors
([Wikipedia](https://en.wikipedia.org/wiki/Cosine_similarity)).
USearch allows storing the vectors in reduced precision formats. Supported scalar kinds are `f64`, `f32`, `f16` or `i8`. If no scalar kind
was specified during index creation, `f16` is used as default.
For normalized data, `L2Distance` is usually a better choice, otherwise `cosineDistance` is recommended to compensate for scale. If no
distance function was specified during index creation, `L2Distance` is used as default.
distance function was specified during index creation, `L2Distance` is used as default.
7 changes: 7 additions & 0 deletions src/Common/ProfileEvents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ The server successfully detected this situation and will download merged part fr
M(PolygonsAddedToPool, "A polygon has been added to the cache (pool) for the 'pointInPolygon' function.") \
M(PolygonsInPoolAllocatedBytes, "The number of bytes for polygons added to the cache (pool) for the 'pointInPolygon' function.") \
\
M(USearchAddCount, "Number of vectors added to usearch indexes.") \
M(USearchAddVisitedMembers, "Number of nodes visited when adding vectors to usearch indexes.") \
M(USearchAddComputedDistances, "Number of times distance was computed when adding vectors to usearch indexes.") \
M(USearchSearchCount, "Number of search operations performed in usearch indexes.") \
M(USearchSearchVisitedMembers, "Number of nodes visited when searching in usearch indexes.") \
M(USearchSearchComputedDistances, "Number of times distance was computed when searching usearch indexes.") \
\
M(RWLockAcquiredReadLocks, "Number of times a read lock was acquired (in a heavy RWLock).") \
M(RWLockAcquiredWriteLocks, "Number of times a write lock was acquired (in a heavy RWLock).") \
M(RWLockReadersWaitMilliseconds, "Total time spent waiting for a read lock to be acquired (in a heavy RWLock).") \
Expand Down
3 changes: 3 additions & 0 deletions src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t
if (rows_read == 0)
return;

if (rows_read > std::numeric_limits<uint32_t>::max())
throw Exception(ErrorCodes::INCORRECT_DATA, "Index granularity is too big: more than 4B rows per index granule.");

if (index_sample_block.columns() > 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");

Expand Down
114 changes: 91 additions & 23 deletions src/Storages/MergeTree/MergeTreeIndexUSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@
#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>

namespace ProfileEvents
{
extern const Event USearchAddCount;
extern const Event USearchAddVisitedMembers;
extern const Event USearchAddComputedDistances;
extern const Event USearchSearchCount;
extern const Event USearchSearchVisitedMembers;
extern const Event USearchSearchComputedDistances;
}

namespace DB
{

Expand All @@ -28,9 +38,20 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}

namespace
{

std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = {
{"f64", unum::usearch::scalar_kind_t::f64_k},
{"f32", unum::usearch::scalar_kind_t::f32_k},
{"f16", unum::usearch::scalar_kind_t::f16_k},
{"i8", unum::usearch::scalar_kind_t::i8_k}};

}

template <unum::usearch::metric_kind_t Metric>
USearchIndexWithSerialization<Metric>::USearchIndexWithSerialization(size_t dimensions)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric)))
USearchIndexWithSerialization<Metric>::USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric, scalar_kind)))
{
}

Expand Down Expand Up @@ -67,9 +88,11 @@ size_t USearchIndexWithSerialization<Metric>::getDimensions() const
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
const String & index_name_,
const Block & index_sample_block_)
const Block & index_sample_block_,
unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, scalar_kind(scalar_kind_)
, index(nullptr)
{
}
Expand All @@ -78,9 +101,11 @@ template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::scalar_kind_t scalar_kind_,
USearchIndexWithSerializationPtr<Metric> index_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, scalar_kind(scalar_kind_)
, index(std::move(index_))
{
}
Expand All @@ -99,23 +124,25 @@ void MergeTreeIndexGranuleUSearch<Metric>::deserializeBinary(ReadBuffer & istr,
{
UInt64 dimension;
readIntBinary(dimension, istr);
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimension);
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimension, scalar_kind);
index->deserialize(istr);
}

template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexAggregatorUSearch<Metric>::MergeTreeIndexAggregatorUSearch(
const String & index_name_,
const Block & index_sample_block_)
const Block & index_sample_block_,
unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, scalar_kind(scalar_kind_)
{
}

template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch<Metric>::getGranuleAndReset()
{
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch<Metric>>(index_name, index_sample_block, index);
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch<Metric>>(index_name, index_sample_block, scalar_kind, index);
index = nullptr;
return granule;
}
Expand All @@ -131,9 +158,13 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
block.rows());

size_t rows_read = std::min(limit, block.rows() - *pos);

if (rows_read == 0)
return;

if (rows_read > std::numeric_limits<uint32_t>::max())
throw Exception(ErrorCodes::INCORRECT_DATA, "Index granularity is too big: more than 4B rows per index granule.");

if (index_sample_block.columns() > 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");

Expand All @@ -151,26 +182,29 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
const auto & offsets = column_array->getOffsets();
const size_t num_rows = offsets.size();


/// Check all sizes are the same
size_t size = offsets[0];
for (size_t i = 0; i < num_rows - 1; ++i)
if (offsets[i + 1] - offsets[i] != size)
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column {} must have equal length", index_column_name);

if (!index)
index = std::make_shared<USearchIndexWithSerialization<Metric>>(size);
index = std::make_shared<USearchIndexWithSerialization<Metric>>(size, scalar_kind);

/// Add all rows of block
if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows)))
throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for usearch index");

if (auto rc = index->add(index->size(), array.data()); !rc)
throw Exception(ErrorCodes::INCORRECT_DATA, rc.error.release());
for (size_t current_row = 1; current_row < num_rows; ++current_row)
if (auto rc = index->add(index->size(), &array[offsets[current_row - 1]]); !rc)
for (size_t current_row = 0; current_row < num_rows; ++current_row)
{
auto rc = index->add(static_cast<uint32_t>(index->size()), &array[offsets[current_row - 1]]);
if (!rc)
throw Exception(ErrorCodes::INCORRECT_DATA, rc.error.release());

ProfileEvents::increment(ProfileEvents::USearchAddCount);
ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, rc.visited_members);
ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, rc.computed_distances);
}
}
else if (const auto & column_tuple = typeid_cast<const ColumnTuple *>(column_cut.get()))
{
Expand All @@ -187,14 +221,21 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
throw Exception(ErrorCodes::LOGICAL_ERROR, "Tuple has 0 rows, {} rows expected", rows_read);

if (!index)
index = std::make_shared<USearchIndexWithSerialization<Metric>>(data[0].size());
index = std::make_shared<USearchIndexWithSerialization<Metric>>(data[0].size(), scalar_kind);

if (!index->reserve(unum::usearch::ceil2(index->size() + data.size())))
throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for usearch index");

for (const auto & item : data)
if (auto rc = index->add(index->size(), item.data()); !rc)
{
auto rc = index->add(static_cast<uint32_t>(index->size()), item.data());
if (!rc)
throw Exception(ErrorCodes::INCORRECT_DATA, rc.error.release());

ProfileEvents::increment(ProfileEvents::USearchAddCount);
ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, rc.visited_members);
ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, rc.computed_distances);
}
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array or Tuple column");
Expand Down Expand Up @@ -257,7 +298,12 @@ std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTre
ann_condition.getDimensions(), index->dimensions());

auto result = index->search(reference_vector.data(), limit);
std::vector<UInt64> neighbors(result.size()); /// indexes of dots which were closest to the reference vector

ProfileEvents::increment(ProfileEvents::USearchSearchCount);
ProfileEvents::increment(ProfileEvents::USearchSearchVisitedMembers, result.visited_members);
ProfileEvents::increment(ProfileEvents::USearchSearchComputedDistances, result.computed_distances);

std::vector<UInt32> neighbors(result.size()); /// indexes of dots which were closest to the reference vector
std::vector<Float32> distances(result.size());
result.dump_to(neighbors.data(), distances.data());

Expand All @@ -277,27 +323,28 @@ std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTre
return granule_numbers;
}

MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_)
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_)
: IMergeTreeIndex(index_)
, distance_function(distance_function_)
, scalar_kind(scalar_kind_)
{
}

MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const
{
if (distance_function == DISTANCE_FUNCTION_L2)
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block);
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block);
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
std::unreachable();
}

MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator() const
{
if (distance_function == DISTANCE_FUNCTION_L2)
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block);
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block);
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
std::unreachable();
}

Expand All @@ -313,18 +360,25 @@ MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index)
if (!index.arguments.empty())
distance_function = index.arguments[0].get<String>();

return std::make_shared<MergeTreeIndexUSearch>(index, distance_function);
static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k;
auto scalar_kind = default_scalar_kind;
if (index.arguments.size() > 1)
scalar_kind = nameToScalarKind.at(index.arguments[1].get<String>());

return std::make_shared<MergeTreeIndexUSearch>(index, distance_function, scalar_kind);
}

void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
{
/// Check number and type of USearch index arguments:

if (index.arguments.size() > 1)
if (index.arguments.size() > 2)
throw Exception(ErrorCodes::INCORRECT_QUERY, "USearch index must not have more than one parameters");

if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::String)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance function argument of USearch index must be of type String");
throw Exception(ErrorCodes::INCORRECT_QUERY, "First argument of USearch index (distance function) must be of type String");
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::String)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Second argument of USearch index (scalar type) must be of type String");

/// Check that the index is created on a single column

Expand All @@ -340,6 +394,20 @@ void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
throw Exception(ErrorCodes::INCORRECT_DATA, "USearch index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
}

/// Check that a supported kind was passed as a second argument

if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].get<String>()))
{
String supported_kinds;
for (const auto & [name, kind] : nameToScalarKind)
{
if (!supported_kinds.empty())
supported_kinds += ", ";
supported_kinds += name;
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for USearch index. Supported kinds are: {}", supported_kinds);
}

/// Check data type of indexed column:

auto throw_unsupported_underlying_column_exception = []()
Expand Down
19 changes: 12 additions & 7 deletions src/Storages/MergeTree/MergeTreeIndexUSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
namespace DB
{

using USearchImplType = unum::usearch::index_dense_gt</* key_at */ uint32_t, /* compressed_slot_at */ uint32_t>;

template <unum::usearch::metric_kind_t Metric>
class USearchIndexWithSerialization : public unum::usearch::index_dense_t
class USearchIndexWithSerialization : public USearchImplType
{
using Base = unum::usearch::index_dense_t;
using Base = USearchImplType;

public:
explicit USearchIndexWithSerialization(size_t dimensions);
USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind);
void serialize(WriteBuffer & ostr) const;
void deserialize(ReadBuffer & istr);
size_t getDimensions() const;
Expand All @@ -31,8 +33,8 @@ using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSeriali
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
{
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, USearchIndexWithSerializationPtr<Metric> index_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr<Metric> index_);

~MergeTreeIndexGranuleUSearch() override = default;

Expand All @@ -43,14 +45,15 @@ struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule

const String index_name;
const Block index_sample_block;
const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index;
};


template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
{
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block);
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexAggregatorUSearch() override = default;

bool empty() const override { return !index || index->size() == 0; }
Expand All @@ -59,6 +62,7 @@ struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator

const String index_name;
const Block index_sample_block;
const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index;
};

Expand Down Expand Up @@ -90,7 +94,7 @@ class MergeTreeIndexConditionUSearch final : public IMergeTreeIndexConditionAppr
class MergeTreeIndexUSearch : public IMergeTreeIndex
{
public:
MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_);
MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_);

~MergeTreeIndexUSearch() override = default;

Expand All @@ -102,6 +106,7 @@ class MergeTreeIndexUSearch : public IMergeTreeIndex

private:
const String distance_function;
const unum::usearch::scalar_kind_t scalar_kind;
};

}
Expand Down
Loading

0 comments on commit f0eadd4

Please sign in to comment.