Skip to content

Commit

Permalink
Add: Support multiple vectors per key
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Aug 14, 2023
1 parent 27b214f commit 7e5f6a7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 74 deletions.
5 changes: 1 addition & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,4 @@
url = https://github.com/ashvardanian/simsimd
[submodule "fp16"]
path = fp16
url = https://github.com/maratyszcza/fp16
[submodule "robin-map"]
path = robin-map
url = https://github.com/tessil/robin-map
url = https://github.com/maratyszcza/fp16
36 changes: 19 additions & 17 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,13 @@ void test_cosine(index_at& index, std::vector<std::vector<scalar_at>> const& vec
index.reserve(10);
index.add(key_first, vector_first, args...);

if constexpr (punned_ak) {
auto result = index.add(key_first, vector_first, args...);
expect(!!result == index.multi());
result.error.release();

std::size_t first_key_count = index.count(key_first);
expect(first_key_count == (1ul + index.multi()));
}

// Default approximate search
key_t matched_labels[10] = {0};
key_t matched_keys[10] = {0};
distance_t matched_distances[10] = {0};
std::size_t matched_count = index.search(vector_first, 5, args...).dump_to(matched_labels, matched_distances);
std::size_t matched_count = index.search(vector_first, 5, args...).dump_to(matched_keys, matched_distances);

expect(matched_count == 1);
expect(matched_labels[0] == key_first);
expect(matched_keys[0] == key_first);
expect(std::abs(matched_distances[0]) < 0.01);

// Add more entries
Expand All @@ -63,7 +54,7 @@ void test_cosine(index_at& index, std::vector<std::vector<scalar_at>> const& vec
expect(index.size() == 3);

// Perform exact search
matched_count = index.search(vector_first, 5, args...).dump_to(matched_labels, matched_distances);
matched_count = index.search(vector_first, 5, args...).dump_to(matched_keys, matched_distances);

// Validate scans
std::size_t count = 0;
Expand Down Expand Up @@ -91,9 +82,9 @@ void test_cosine(index_at& index, std::vector<std::vector<scalar_at>> const& vec
// Search again over reconstructed index
index.save("tmp.usearch");
index.load("tmp.usearch");
matched_count = index.search(vector_first, 5, args...).dump_to(matched_labels, matched_distances);
matched_count = index.search(vector_first, 5, args...).dump_to(matched_keys, matched_distances);
expect(matched_count == 3);
expect(matched_labels[0] == key_first);
expect(matched_keys[0] == key_first);
expect(std::abs(matched_distances[0]) < 0.01);

if constexpr (punned_ak) {
Expand All @@ -115,13 +106,24 @@ void test_cosine(index_at& index, std::vector<std::vector<scalar_at>> const& vec
}
});

// Check for duplicates
if constexpr (punned_ak) {
index.reserve({vectors.size() + 1u, executor.size()});
auto result = index.add(key_first, vector_first, args...);
expect(!!result == index.multi());
result.error.release();

std::size_t first_key_count = index.count(key_first);
expect(first_key_count == (1ul + index.multi()));
}

// Search again over mapped index
// file_head_result_t head = index_dense_metadata("tmp.usearch");
// expect(head.size == 3);
index.view("tmp.usearch");
matched_count = index.search(vector_first, 5, args...).dump_to(matched_labels, matched_distances);
matched_count = index.search(vector_first, 5, args...).dump_to(matched_keys, matched_distances);
expect(matched_count == 3);
expect(matched_labels[0] == key_first);
expect(matched_keys[0] == key_first);
expect(std::abs(matched_distances[0]) < 0.01);

if constexpr (punned_ak) {
Expand Down
4 changes: 3 additions & 1 deletion include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,9 @@ class ring_gt {
bool reserve(std::size_t n) noexcept {
if (n < size())
return false; // prevent data loss
n = (std::max<std::size_t>)(n, 64u);
if (n <= capacity())
return true;
n = (std::max<std::size_t>)(ceil2(n), 64u);
element_t* elements = allocator_.allocate(n);
if (!elements)
return false;
Expand Down
114 changes: 67 additions & 47 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#pragma once
#include <stdlib.h> // `aligned_alloc`

#include <functional> // `std::function`
#include <numeric> // `std::iota`
#include <shared_mutex> // `std::shared_mutex`
#include <thread> // `std::thread`
#include <vector> // `std::vector`
#include <functional> // `std::function`
#include <numeric> // `std::iota`
#include <shared_mutex> // `std::shared_mutex`
#include <thread> // `std::thread`
#include <unordered_set> // `std::unordered_multiset`
#include <vector> // `std::vector`

#include <usearch/index.hpp>
#include <usearch/index_plugins.hpp>
Expand Down Expand Up @@ -323,8 +324,13 @@ class index_dense_gt {
struct key_and_slot_t {
key_t key;
compressed_slot_t slot;

bool any_slot() const { return slot == default_free_value<compressed_slot_t>(); }
static key_and_slot_t any_slot(key_t key) { return {key, default_free_value<compressed_slot_t>()}; }
};

key_and_slot_t key_and_any_slot() {}

struct lookup_key_hash_t {
using is_transparent = void;
std::size_t operator()(key_and_slot_t const& k) const noexcept { return std::hash<key_t>{}(k.key); }
Expand All @@ -336,12 +342,12 @@ class index_dense_gt {
bool operator()(key_and_slot_t const& a, key_t const& b) const noexcept { return a.key == b; }
bool operator()(key_t const& a, key_and_slot_t const& b) const noexcept { return a == b.key; }
bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept {
return a.key == b.key && a.slot == b.slot;
return (!a.any_slot() & !b.any_slot()) ? a.key == b.key && a.slot == b.slot : a.key == b.key;
}
};

/// @brief Multi-Map from keys to IDs, and allocated vectors.
tsl::robin_set<key_and_slot_t, lookup_key_hash_t, lookup_key_same_t> slot_lookup_;
std::unordered_multiset<key_and_slot_t, lookup_key_hash_t, lookup_key_same_t> slot_lookup_;

/// @brief Mutex, controlling concurrent access to `slot_lookup_`.
mutable shared_mutex_t slot_lookup_mutex_;
Expand Down Expand Up @@ -855,7 +861,7 @@ class index_dense_gt {
*/
bool contains(key_t key) const {
shared_lock_t lock(slot_lookup_mutex_);
return slot_lookup_.contains(key);
return slot_lookup_.find(key_and_slot_t::any_slot(key)) != slot_lookup_.end();
}

/**
Expand All @@ -864,7 +870,7 @@ class index_dense_gt {
*/
std::size_t count(key_t key) const {
shared_lock_t lock(slot_lookup_mutex_);
return slot_lookup_.count(key);
return slot_lookup_.count(key_and_slot_t::any_slot(key));
}

struct labeling_result_t {
Expand All @@ -890,65 +896,70 @@ class index_dense_gt {
labeling_result_t result;

unique_lock_t lookup_lock(slot_lookup_mutex_);
auto labeled_iterator = slot_lookup_.find(key);
if (labeled_iterator == slot_lookup_.end())
auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key));
if (matching_slots.first == matching_slots.second)
return result;

// Grow the removed entries ring, if needed
std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second);
std::unique_lock<std::mutex> free_lock(free_keys_mutex_);
if (free_keys_.size() == free_keys_.capacity())
if (!free_keys_.reserve((std::max<std::size_t>)(free_keys_.capacity() * 2, 64ul)))
return result.failed("Can't allocate memory for a free-list");
if (!free_keys_.reserve(free_keys_.size() + matching_count))
return result.failed("Can't allocate memory for a free-list");

// A removed entry would be:
// - present in `free_keys_`
// - missing in the `slot_lookup_`
// - marked in the `typed_` index with a `free_key_`
compressed_slot_t slot = (*labeled_iterator).slot;
free_keys_.push(slot);
slot_lookup_.erase(labeled_iterator);
typed_->at(slot).key = free_key_;
result.completed = true;
for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) {
compressed_slot_t slot = (*slots_it).slot;
free_keys_.push(slot);
typed_->at(slot).key = free_key_;
}
slot_lookup_.erase(matching_slots.first, matching_slots.second);
result.completed = matching_count;

return result;
}

/**
* @brief Removes multiple entries with the specified keys from the index.
* @param[in] labels_begin The beginning of the keys range.
* @param[in] labels_end The ending of the keys range.
* @param[in] keys_begin The beginning of the keys range.
* @param[in] keys_end The ending of the keys range.
* @return The ::labeling_result_t indicating the result of the removal operation.
* `result.completed` will contain the number of keys that were successfully removed.
* `result.error` will contain an error message if an error occurred during the removal operation.
*/
template <typename labels_iterator_at>
labeling_result_t remove(labels_iterator_at&& labels_begin, labels_iterator_at&& labels_end) {
labeling_result_t remove(labels_iterator_at keys_begin, labels_iterator_at keys_end) {

labeling_result_t result;
unique_lock_t lookup_lock(slot_lookup_mutex_);
std::unique_lock<std::mutex> free_lock(free_keys_mutex_);

// Grow the removed entries ring, if needed
std::size_t count_requests = std::distance(labels_begin, labels_end);
if (!free_keys_.reserve(free_keys_.size() + count_requests))
std::size_t matching_count = 0;
for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it)
matching_count += slot_lookup_.count(key_and_slot_t::any_slot(*keys_it));

if (!free_keys_.reserve(free_keys_.size() + matching_count))
return result.failed("Can't allocate memory for a free-list");

// Remove them one-by-one
for (auto label_it = labels_begin; label_it != labels_end; ++label_it) {
key_t key = *label_it;
auto labeled_iterator = slot_lookup_.find(key);
if (labeled_iterator == slot_lookup_.end())
continue;

for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) {
key_t key = *keys_it;
auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key));
// A removed entry would be:
// - present in `free_keys_`
// - missing in the `slot_lookup_`
// - marked in the `typed_` index with a `free_key_`
compressed_slot_t slot = (*labeled_iterator).slot;
free_keys_.push(slot);
slot_lookup_.erase(labeled_iterator);
typed_->at(slot).key = free_key_;
result.completed += 1;
for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) {
compressed_slot_t slot = (*slots_it).slot;
free_keys_.push(slot);
typed_->at(slot).key = free_key_;
}

matching_count = std::distance(matching_slots.first, matching_slots.second);
slot_lookup_.erase(matching_slots.first, matching_slots.second);
result.completed += matching_count;
}

return result;
Expand All @@ -965,16 +976,24 @@ class index_dense_gt {
labeling_result_t rename(key_t from, key_t to) {
labeling_result_t result;
unique_lock_t lookup_lock(slot_lookup_mutex_);
auto labeled_iterator = slot_lookup_.find(from);
if (labeled_iterator == slot_lookup_.end())
return result;

compressed_slot_t slot = (*labeled_iterator).slot;
key_and_slot_t key_and_slot{to, slot};
slot_lookup_.erase(labeled_iterator);
slot_lookup_.insert(key_and_slot);
typed_->at(slot).key = to;
result.completed = true;
if (!multi() && slot_lookup_.count(key_and_slot_t::any_slot(to)))
return result.failed("Renaming impossible, the key is already in use");

// The `from` may map to multiple entries
while (true) {
auto slots_it = slot_lookup_.find(key_and_slot_t::any_slot(from));
if (slots_it == slot_lookup_.end())
break;

compressed_slot_t slot = (*slots_it).slot;
key_and_slot_t key_and_slot{to, slot};
slot_lookup_.erase(slots_it);
slot_lookup_.insert(key_and_slot);
typed_->at(slot).key = to;
++result.completed;
}

return result;
}

Expand Down Expand Up @@ -1328,7 +1347,7 @@ class index_dense_gt {
// Find the matching ID
{
shared_lock_t lock(slot_lookup_mutex_);
auto it = slot_lookup_.find(key);
auto it = slot_lookup_.find(key_and_slot_t::any_slot(key));
if (it == slot_lookup_.end())
return false;
slot = (*it).slot;
Expand All @@ -1341,10 +1360,11 @@ class index_dense_gt {
return true;
} else {
shared_lock_t lock(slot_lookup_mutex_);
auto equal_range_pair = slot_lookup_.equal_range(key);
auto equal_range_pair = slot_lookup_.equal_range(key_and_slot_t::any_slot(key));
std::size_t count_exported = 0;
for (auto begin = equal_range_pair.first;
begin != equal_range_pair.second && count_exported != vectors_limit; ++begin, ++count_exported) {
//
compressed_slot_t slot = (*begin).slot;
byte_t const* punned_vector = reinterpret_cast<byte_t const*>(vectors_lookup_[slot]);
byte_t* reconstructed_vector = (byte_t*)reconstructed + metric_.bytes_per_vector() * count_exported;
Expand Down
6 changes: 1 addition & 5 deletions python/scripts/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ def test_index_duplicates(batch_size):
assert len(index) == batch_size * 2

two_per_key = index.get(keys)
print(two_per_key)
if batch_size == 1:
assert two_per_key.shape == (2, ndim)
else:
assert np.vstack(two_per_key).shape == (2 * batch_size, ndim)
assert np.vstack(two_per_key).shape == (2 * batch_size, ndim)


@pytest.mark.parametrize("batch_size", [1, 7, 1024])
Expand Down

0 comments on commit 7e5f6a7

Please sign in to comment.