Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Progress tracking and conditional termination from Python #206 #211

Merged
merged 19 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ static void single_shot(dataset_at& dataset, index_at& index, bool construct = t
executor, [&](std::size_t progress, std::size_t total) {
if (progress % 1000 == 0)
printer.print(progress, total);
return true;
});
join_attempts = result.visited_members;
}
Expand Down
60 changes: 48 additions & 12 deletions include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1212,9 +1212,10 @@ struct dummy_callback_t {
*
* This is particularly helpful when handling long-running tasks, like serialization,
* saving, and loading from disk, or index-level joins.
The caller inspects the reporter's boolean return value to decide whether to continue: returning `false` requests that the process stop.
*/
struct dummy_progress_t {
    // Default no-op reporter: ignores the (processed, total) counters and always
    // returns `true`, i.e. it never asks the long-running operation to stop.
    // (The stale `void`-returning overload from the old revision is removed —
    // two call operators differing only in return type cannot coexist.)
    inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept { return true; }
};

/**
Expand Down Expand Up @@ -2690,22 +2691,29 @@ class index_gt {
if (!output(&header, sizeof(header)))
return result.failed("Failed to serialize the header into stream");

// Progress status
std::size_t processed = 0;
std::size_t const total = 2 * header.size;

// Export the number of levels per node
// That is both enough to estimate the overall memory consumption,
// and to be able to estimate the offsets of every entry in the file.
for (std::size_t i = 0; i != header.size; ++i) {
node_t node = node_at_(i);
level_t level = node.level();
if (!output(&level, sizeof(level)))
return result.failed("Failed to serialize nodes levels into stream");
return result.failed("Failed to serialize into stream");
if (!progress(++processed, total))
return result.failed("Terminated by user");
}

// After that dump the nodes themselves
for (std::size_t i = 0; i != header.size; ++i) {
span_bytes_t node_bytes = node_bytes_(node_at_(i));
if (!output(node_bytes.data(), node_bytes.size()))
return result.failed("Failed to serialize nodes into stream");
progress(i, header.size);
return result.failed("Failed to serialize into stream");
if (!progress(++processed, total))
return result.failed("Terminated by user");
}

return {};
Expand Down Expand Up @@ -2763,7 +2771,8 @@ class index_gt {
return result.failed("Failed to pull nodes from the stream");
}
nodes_[i] = node_t{node_bytes.data()};
progress(i, header.size);
if (!progress(i + 1, header.size))
return result.failed("Terminated by user");
}
return {};
}
Expand Down Expand Up @@ -2936,7 +2945,8 @@ class index_gt {
// Rapidly address all the nodes
for (std::size_t i = 0; i != header.size; ++i) {
nodes_[i] = node_t{(byte_t*)file.data() + offsets[i]};
progress(i, header.size);
if (!progress(i + 1, header.size))
return result.failed("Terminated by user");
}
viewed_file_ = std::move(file);
return {};
Expand Down Expand Up @@ -2986,8 +2996,13 @@ class index_gt {
using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc<slot_level_t>;
buffer_gt<slot_level_t, slot_level_allocator_t> slots_and_levels(size());

// Progress status
std::atomic<bool> do_tasks{true};
mgevor marked this conversation as resolved.
Show resolved Hide resolved
std::atomic<std::size_t> processed{0};
std::size_t const total = 3 * slots_and_levels.size();

// For every bottom level node, determine its parent cluster
executor.fixed(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) {
executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) {
context_t& context = contexts_[thread_idx];
std::size_t cluster = search_for_one_( //
values[citerator_at(old_slot)], //
Expand All @@ -2997,7 +3012,13 @@ class index_gt {
static_cast<compressed_slot_t>(old_slot), //
static_cast<compressed_slot_t>(cluster), //
node_at_(old_slot).level()};
++processed;
if (thread_idx == 0)
do_tasks = progress(processed.load(), total);
return do_tasks.load();
});
if (!do_tasks.load())
return;

// Where the actual permutation happens:
std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const& a, slot_level_t const& b) {
Expand Down Expand Up @@ -3027,15 +3048,17 @@ class index_gt {
neighbor = static_cast<compressed_slot_t>(old_slot_to_new[compressed_slot_t(neighbor)]);

reordered_nodes[new_slot] = new_node;

progress(new_slot, slots_and_levels.size());
if (!progress(++processed, total))
return;
}

for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) {
std::size_t old_slot = slots_and_levels[new_slot].old_slot;
slot_transition(node_at_(old_slot).ckey(), //
static_cast<compressed_slot_t>(old_slot), //
static_cast<compressed_slot_t>(new_slot));
if (!progress(++processed, total))
return;
}

nodes_ = std::move(reordered_nodes);
Expand Down Expand Up @@ -3064,9 +3087,13 @@ class index_gt {
executor_at&& executor = executor_at{}, //
progress_at&& progress = progress_at{}) noexcept {

// Progress status
std::atomic<bool> do_tasks{true};
std::atomic<std::size_t> processed{0};

// Erase all the incoming links
std::size_t nodes_count = size();
executor.fixed(nodes_count, [&](std::size_t, std::size_t node_idx) {
executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) {
node_t node = node_at_(node_idx);
for (level_t level = 0; level <= node.level(); ++level) {
neighbors_ref_t neighbors = neighbors_(node, level);
Expand All @@ -3079,8 +3106,14 @@ class index_gt {
neighbors.push_back(neighbor_slot);
}
}
progress(node_idx, nodes_count);
++processed;
if (thread_idx == 0)
do_tasks = progress(processed.load(), nodes_count);
return do_tasks.load();
});

// Report the final counters once at the end, because the reporting thread (thread 0) may have finished before the others
progress(processed.load(), nodes_count);
}

private:
Expand Down Expand Up @@ -3710,7 +3743,10 @@ static join_result_t join( //
passed_rounds = ++rounds;
total_rounds = passed_rounds + free_men.size();
}
progress(passed_rounds, total_rounds);
if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) {
atomic_error.store("Terminated by user");
break;
}
while (men_locks.atomic_set(free_man_slot))
;

Expand Down
63 changes: 43 additions & 20 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,9 @@ class index_dense_gt {
* @brief Saves serialized binary index representation to a stream.
*/
template <typename output_callback_at, typename progress_at = dummy_progress_t>
serialization_result_t save_to_stream(output_callback_at&& output, serialization_config_t config = {}) const {
serialization_result_t save_to_stream(output_callback_at&& output, //
serialization_config_t config = {}, //
progress_at&& progress = {}) const {

serialization_result_t result;
std::uint64_t matrix_rows = 0;
Expand Down Expand Up @@ -831,7 +833,7 @@ class index_dense_gt {
}

// Save the actual proximity graph
return typed_->save_to_stream(std::forward<output_callback_at>(output));
return typed_->save_to_stream(std::forward<output_callback_at>(output), std::forward<progress_at>(progress));
}

/**
Expand All @@ -853,8 +855,10 @@ class index_dense_gt {
* @param[in] config Configuration parameters for imports.
* @return Outcome descriptor explicitly convertible to boolean.
*/
template <typename input_callback_at>
serialization_result_t load_from_stream(input_callback_at&& input, serialization_config_t config = {}) {
template <typename input_callback_at, typename progress_at = dummy_progress_t>
serialization_result_t load_from_stream(input_callback_at&& input, //
serialization_config_t config = {}, //
progress_at&& progress = {}) {

// Discard all previous memory allocations of `vectors_tape_allocator_`
reset();
Expand Down Expand Up @@ -915,7 +919,7 @@ class index_dense_gt {
}

// Pull the actual proximity graph
result = typed_->load_from_stream(std::forward<input_callback_at>(input));
result = typed_->load_from_stream(std::forward<input_callback_at>(input), std::forward<progress_at>(progress));
if (!result)
return result;
if (typed_->size() != static_cast<std::size_t>(matrix_rows))
Expand All @@ -931,7 +935,10 @@ class index_dense_gt {
* @param[in] config Configuration parameters for imports.
* @return Outcome descriptor explicitly convertible to boolean.
*/
serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, serialization_config_t config = {}) {
template <typename progress_at = dummy_progress_t>
serialization_result_t view(memory_mapped_file_t file, //
std::size_t offset = 0, serialization_config_t config = {}, //
progress_at&& progress = {}) {

// Discard all previous memory allocations of `vectors_tape_allocator_`
reset();
Expand Down Expand Up @@ -997,7 +1004,7 @@ class index_dense_gt {
}

// Pull the actual proximity graph
result = typed_->view(std::move(file), offset);
result = typed_->view(std::move(file), offset, std::forward<progress_at>(progress));
if (!result)
return result;
if (typed_->size() != static_cast<std::size_t>(matrix_rows))
Expand All @@ -1019,7 +1026,9 @@ class index_dense_gt {
* @param[in] config Configuration parameters for exports.
* @return Outcome descriptor explicitly convertible to boolean.
*/
serialization_result_t save(output_file_t file, serialization_config_t config = {}) const {
template <typename progress_at = dummy_progress_t>
serialization_result_t save(output_file_t file, serialization_config_t config = {},
progress_at&& progress = {}) const {

serialization_result_t io_result = file.open_if_not();
if (!io_result)
Expand All @@ -1030,7 +1039,7 @@ class index_dense_gt {
io_result = file.write(buffer, length);
return !!io_result;
},
config);
config, std::forward<progress_at>(progress));

if (!stream_result)
return stream_result;
Expand All @@ -1041,8 +1050,11 @@ class index_dense_gt {
* @brief Memory-maps the serialized binary index representation from disk,
* @b without copying data into RAM, and fetching it on-demand.
*/
serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0,
serialization_config_t config = {}) const {
template <typename progress_at = dummy_progress_t>
serialization_result_t save(memory_mapped_file_t file, //
std::size_t offset = 0, //
serialization_config_t config = {}, //
progress_at&& progress = {}) const {

serialization_result_t io_result = file.open_if_not();
if (!io_result)
Expand All @@ -1056,7 +1068,7 @@ class index_dense_gt {
offset += length;
return true;
},
config);
config, std::forward<progress_at>(progress));

return stream_result;
}
Expand All @@ -1067,7 +1079,8 @@ class index_dense_gt {
* @param[in] config Configuration parameters for imports.
* @return Outcome descriptor explicitly convertible to boolean.
*/
serialization_result_t load(input_file_t file, serialization_config_t config = {}) {
template <typename progress_at = dummy_progress_t>
serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at&& progress = {}) {

serialization_result_t io_result = file.open_if_not();
if (!io_result)
Expand All @@ -1078,7 +1091,7 @@ class index_dense_gt {
io_result = file.read(buffer, length);
return !!io_result;
},
config);
config, std::forward<progress_at>(progress));

if (!stream_result)
return stream_result;
Expand All @@ -1089,7 +1102,11 @@ class index_dense_gt {
* @brief Memory-maps the serialized binary index representation from disk,
* @b without copying data into RAM, and fetching it on-demand.
*/
serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, serialization_config_t config = {}) {
template <typename progress_at = dummy_progress_t>
serialization_result_t load(memory_mapped_file_t file, //
std::size_t offset = 0, //
serialization_config_t config = {}, //
progress_at&& progress = {}) {

serialization_result_t io_result = file.open_if_not();
if (!io_result)
Expand All @@ -1103,17 +1120,23 @@ class index_dense_gt {
offset += length;
return true;
},
config);
config, std::forward<progress_at>(progress));

return stream_result;
}

serialization_result_t save(char const* file_path, serialization_config_t config = {}) const {
return save(output_file_t(file_path), config);
template <typename progress_at = dummy_progress_t>
serialization_result_t save(char const* file_path, //
serialization_config_t config = {}, //
progress_at&& progress = {}) const {
return save(output_file_t(file_path), config, std::forward<progress_at>(progress));
}

serialization_result_t load(char const* file_path, serialization_config_t config = {}) {
return load(input_file_t(file_path), config);
template <typename progress_at = dummy_progress_t>
serialization_result_t load(char const* file_path, //
serialization_config_t config = {}, //
progress_at&& progress = {}) {
return load(input_file_t(file_path), config, std::forward<progress_at>(progress));
}

/**
Expand Down
Loading
Loading