Skip to content

Commit

Permalink
query improved
Browse files Browse the repository at this point in the history
  • Loading branch information
brj0 committed Jul 16, 2023
1 parent 190c153 commit 7fd73ca
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 45 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.10)

project(nndescent
VERSION 1.0.3
VERSION 1.0.4
LANGUAGES CXX
DESCRIPTION "A C++ implementation of the nearest neighbour descent algorithm")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def get_version_string():
"-flto",
"-fno-math-errno",
"-fopenmp",
"-g",
"-march=native",
"-mtune=native",
]

module = Extension(
Expand Down
4 changes: 2 additions & 2 deletions src/dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void print_map(Matrix<float> matrix)
}


std::ostream& operator<<(std::ostream &out, NNUpdate &update)
std::ostream& operator<<(std::ostream &out, const NNUpdate &update)
{
out << "(idx0=" << update.idx0
<< ", idx1=" << update.idx1
Expand All @@ -67,7 +67,7 @@ std::ostream& operator<<(std::ostream &out, NNUpdate &update)
}


std::ostream& operator<<(std::ostream &out, std::vector<NNUpdate> &updates)
std::ostream& operator<<(std::ostream &out, const std::vector<NNUpdate> &updates)
{
out << "[";
for (size_t i = 0; i < updates.size(); ++i)
Expand Down
85 changes: 83 additions & 2 deletions src/dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,18 @@ class HeapList
*/
int checked_push(size_t i, int idx, KeyType key);

/*
* Pushes a node with the specified index and key into the specified heap
* if its key is strictly smaller than the key currently stored at the root
* (position 0) of that heap. On success the root entry is discarded and the
* new node is sifted down to restore the max heap criterion.
*
* No test is made whether a node with the same index is already contained
* in the heap; the caller has to guarantee that pushed indices are
* different.
*
* @param i The index of the heap.
* @param idx The index of the node.
* @param key The key associated with the node.
*
* @return 1 if the node was added to the heap, 0 otherwise.
*/
int simple_push(size_t i, int idx, KeyType key);

/*
* Performs a "siftdown" operation on the specified heap starting from the
* given index.
Expand Down Expand Up @@ -1279,6 +1291,75 @@ int HeapList<KeyType>::checked_push(size_t i, int idx, KeyType key)
}


template <class KeyType>
int HeapList<KeyType>::simple_push(size_t i, int idx, KeyType key)
{
    // Only keys strictly smaller than the root key of heap i are accepted.
    if (key >= keys(i, 0))
    {
        return 0;
    }

    // The old root entry is dropped. A "hole" starting at the root is moved
    // down the heap: as long as the larger child holds a key greater than
    // the new key, that child is pulled up into the hole.
    size_t hole = 0;
    for (size_t child = 1; child < n_nodes; child = 2*hole + 1)
    {
        const size_t sibling = child + 1;

        // Prefer the left child; take the right one only if it exists and
        // the left key is not greater than or equal to the right key.
        if (sibling < n_nodes && !(keys(i, child) >= keys(i, sibling)))
        {
            child = sibling;
        }

        // Max heap criterion already holds for the new key at this position.
        if (!(keys(i, child) > key))
        {
            break;
        }

        indices(i, hole) = indices(i, child);
        keys(i, hole) = keys(i, child);
        hole = child;
    }

    // Place the new node into the final position of the hole.
    indices(i, hole) = idx;
    keys(i, hole) = key;

    return 1;
}


template <class KeyType>
size_t HeapList<KeyType>::size(size_t i) const
{
Expand Down Expand Up @@ -1409,13 +1490,13 @@ typedef struct
/*
* @brief Prints a NNUpdate object to an output stream.
*/
std::ostream& operator<<(std::ostream &out, NNUpdate &update);
std::ostream& operator<<(std::ostream &out, const NNUpdate &update);


/*
* @brief Prints a vector of NNUpdate objects to an output stream.
*/
std::ostream& operator<<(std::ostream &out, std::vector<NNUpdate> &updates);
std::ostream& operator<<(std::ostream &out, const std::vector<NNUpdate> &updates);


} // namespace nndescent
14 changes: 8 additions & 6 deletions src/nnd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void update_by_leaves(
{
int block_start = thread * block_size;
int block_end = (thread + 1) * block_size;
block_end = (thread == n_threads) ? n_leaves : block_end;
block_end = (thread == n_threads - 1) ? n_leaves : block_end;

for (int i = block_start; i < block_end; ++i)
{
Expand Down Expand Up @@ -186,7 +186,7 @@ void sample_candidates(
HeapList<float> &current_graph,
HeapList<int> &new_candidates,
HeapList<int> &old_candidates,
RandomState rng_state,
const RandomState &rng_state,
int n_threads
)
{
Expand Down Expand Up @@ -363,7 +363,7 @@ std::vector<std::vector<NNUpdate>> generate_graph_updates(
*/
int apply_graph_updates(
HeapList<float> &current_graph,
std::vector<std::vector<NNUpdate>> &updates,
const std::vector<std::vector<NNUpdate>> &updates,
int n_threads
)
{
Expand Down Expand Up @@ -825,8 +825,8 @@ void NNDescent::prepare(const DistType &dist)
);
}

size_t n_seach_cols = std::round(n_neighbors * pruning_degree_multiplier);
search_graph = HeapList<float>(data_size, n_seach_cols, FLOAT_MAX);
size_t n_search_cols = std::round(n_neighbors * pruning_degree_multiplier);
search_graph = HeapList<float>(data_size, n_search_cols, FLOAT_MAX);

for (size_t i = 0; i < forward_graph.nheaps(); ++i)
{
Expand All @@ -841,6 +841,7 @@ void NNDescent::prepare(const DistType &dist)
}
}
}

search_graph.heapsort();

if (verbose)
Expand All @@ -851,6 +852,7 @@ void NNDescent::prepare(const DistType &dist)
+ " edges for the search graph."
);
}

}


Expand All @@ -867,7 +869,7 @@ void NNDescent::start_brute_force(
for (size_t idx1 = 0; idx1 < train_data.nrows(); ++idx1)
{
float d = dist(train_data, idx0, idx1);
current_graph.checked_push(idx0, idx1, d);
current_graph.simple_push(idx0, idx1, d);
}
}
current_graph.heapsort();
Expand Down
33 changes: 19 additions & 14 deletions src/nnd.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ namespace nndescent


/**
* @version 1.0.3
* @version 1.0.4
*/
const std::string PROJECT_VERSION = "1.0.3";
const std::string PROJECT_VERSION = "1.0.4";


// Constants
Expand Down Expand Up @@ -560,7 +560,7 @@ void NNDescent::query_brute_force(
for (size_t idx_t = 0; idx_t < data_size; ++idx_t)
{
float d = dist(train_data, idx_t, query_data, idx_q);
query_nn.checked_push(idx_q, idx_t, d);
query_nn.simple_push(idx_q, idx_t, d);
}
}
query_nn.heapsort();
Expand Down Expand Up @@ -830,12 +830,13 @@ void NNDescent::query(
prepare(dist);
}
HeapList<float> query_nn(_query_data.nrows(), k, FLOAT_MAX);
#pragma omp parallel for num_threads(n_threads)
for (size_t i = 0; i < query_nn.nheaps(); ++i)
{

// Initialization
Heap<Candidate> search_candidates;
std::vector<int> visited(data_size, 0);
std::vector<bool> visited(data_size, 0);
std::vector<int> initial_candidates = search_tree.get_leaf(
_query_data, i, rng_state
);
Expand All @@ -844,10 +845,9 @@ void NNDescent::query(
{
float d = dist(train_data, idx, _query_data, i);
// Don't need to check as indices are guaranteed to be different.
// TODO implement push without check.
query_nn.checked_push(i, idx, d);
visited[idx] = 1;
query_nn.simple_push(i, idx, d);
search_candidates.push({idx, d});
visited[idx] = 1;
}
int n_random_samples = k - initial_candidates.size();
for (int j = 0; j < n_random_samples; ++j)
Expand All @@ -856,9 +856,9 @@ void NNDescent::query(
if (!visited[idx])
{
float d = dist(train_data, idx, _query_data, i);
query_nn.checked_push(i, idx, d);
visited[idx] = 1;
query_nn.simple_push(i, idx, d);
search_candidates.push({idx, d});
visited[idx] = 1;
}
}

Expand All @@ -867,9 +867,13 @@ void NNDescent::query(
float distance_bound = (1.0f + epsilon) * query_nn.max(i);
while (candidate.key < distance_bound)
{
for (size_t j = 0; j < search_graph.nnodes(); ++j)
for (
auto it = search_graph.indices.begin(candidate.idx);
it != search_graph.indices.end(candidate.idx);
++it
)
{
int idx = search_graph.indices(candidate.idx, j);
int idx = *it;
if (idx == NONE)
{
break;
Expand All @@ -882,14 +886,13 @@ void NNDescent::query(
float d = dist(train_data, idx, _query_data, i);
if (d < distance_bound)
{
query_nn.checked_push(i, idx, d);
query_nn.simple_push(i, idx, d);
search_candidates.push({idx, d});

// Update bound
distance_bound = (1.0f + epsilon) * query_nn.max(i);
}
}
// Find new nearest candidate point.
// The next candidate is the nearest among the search_candidates.
if (search_candidates.empty())
{
break;
Expand All @@ -899,7 +902,9 @@ void NNDescent::query(
candidate = search_candidates.pop();
}
}

}

query_nn.heapsort();
query_indices = query_nn.indices;
query_distances = query_nn.keys;
Expand Down
10 changes: 5 additions & 5 deletions src/rp_trees.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ template<>
std::tuple<std::vector<int>, std::vector<int>, std::vector<float>, float>
random_projection_split<EuclideanSplit>(
const Matrix<float> &data,
std::vector<int> &indices,
const std::vector<int> &indices,
RandomState &rng_state
)
{
Expand Down Expand Up @@ -136,7 +136,7 @@ template<>
std::tuple<std::vector<int>, std::vector<int>, std::vector<float>, float>
random_projection_split<AngularSplit>(
const Matrix<float> &data,
std::vector<int> &indices,
const std::vector<int> &indices,
RandomState &rng_state
)
{
Expand Down Expand Up @@ -295,7 +295,7 @@ std::tuple<
>
sparse_random_projection_split<EuclideanSplit>(
const CSRMatrix<float> &data,
std::vector<int> &indices,
const std::vector<int> &indices,
RandomState &rng_state
)
{
Expand Down Expand Up @@ -443,7 +443,7 @@ std::tuple<
>
sparse_random_projection_split<AngularSplit>(
const CSRMatrix<float> &data,
std::vector<int> &indices,
const std::vector<int> &indices,
RandomState &rng_state
)
{
Expand Down Expand Up @@ -616,7 +616,7 @@ sparse_random_projection_split<AngularSplit>(


Matrix<int> get_leaves_from_forest(
std::vector<RPTree> &forest
const std::vector<RPTree> &forest
)
{
size_t leaf_size = forest[0].leaf_size;
Expand Down
Loading

0 comments on commit 7fd73ca

Please sign in to comment.