Skip to content

Commit

Permalink
Optimize: don't reread node for neighbor list
Browse files Browse the repository at this point in the history
  • Loading branch information
cevian committed Dec 14, 2023
1 parent 8955ced commit 041db69
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 61 deletions.
19 changes: 6 additions & 13 deletions timescale_vector/src/access_method/builder_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,14 @@ impl<'a> Graph for BuilderGraph<'a> {
self.meta_page.get_init_ids()
}

fn get_neighbors(
&self,
_index: &PgRelation,
neighbors_of: ItemPointer,
result: &mut Vec<IndexPointer>,
) -> bool {
fn get_neighbors(&self, _node: &ArchivedNode, neighbors_of: ItemPointer) -> Vec<IndexPointer> {
let neighbors = self.neighbor_map.get(&neighbors_of);
match neighbors {
Some(n) => {
for nwd in n {
result.push(nwd.get_index_pointer_to_neighbor());
}
true
}
None => false,
Some(n) => n
.iter()
.map(|n| n.get_index_pointer_to_neighbor())
.collect(),
None => vec![],
}
}

Expand Down
17 changes: 6 additions & 11 deletions timescale_vector/src/access_method/disk_index_graph.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use pgrx::PgRelation;

use crate::util::{IndexPointer, ItemPointer};
use crate::util::ItemPointer;

use super::{
graph::{Graph, VectorProvider},
meta_page::MetaPage,
model::{NeighborWithDistance, Node, ReadableNode},
model::{ArchivedNode, NeighborWithDistance, Node, ReadableNode},
};

pub struct DiskIndexGraph<'a> {
Expand Down Expand Up @@ -36,18 +36,13 @@ impl<'h> Graph for DiskIndexGraph<'h> {
self.meta_page.get_init_ids()
}

fn get_neighbors(
&self,
index: &PgRelation,
neighbors_of: ItemPointer,
result: &mut Vec<IndexPointer>,
) -> bool {
let rn = self.read(index, neighbors_of);
rn.get_archived_node().apply_to_neighbors(|n| {
fn get_neighbors(&self, node: &ArchivedNode, _neighbors_of: ItemPointer) -> Vec<ItemPointer> {
let mut result = Vec::with_capacity(node.num_neighbors());
node.apply_to_neighbors(|n| {
let n = n.deserialize_item_pointer();
result.push(n)
});
true
result
}

fn get_neighbors_with_distances(
Expand Down
74 changes: 37 additions & 37 deletions timescale_vector/src/access_method/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::access_method::model::Node;
use crate::util::ports::slot_getattr;
use crate::util::{HeapPointer, IndexPointer, ItemPointer};

use super::model::PgVector;
use super::model::{ArchivedNode, PgVector};
use super::quantizer::Quantizer;
use super::{
meta_page::MetaPage,
Expand Down Expand Up @@ -128,44 +128,37 @@ impl<'a> VectorProvider<'a> {

unsafe fn get_distance(
&self,
index: &PgRelation,
index_pointer: IndexPointer,
node: &ArchivedNode,
query: &[f32],
dm: &DistanceMeasure,
stats: &mut GreedySearchStats,
) -> (f32, HeapPointer) {
) -> f32 {
if self.calc_distance_with_quantizer {
let rn = unsafe { Node::read(index, index_pointer) };
stats.node_reads += 1;
let node = rn.get_archived_node();
assert!(node.pq_vector.len() > 0);
let vec = node.pq_vector.as_slice();
let distance = dm.get_quantized_distance(vec);
stats.pq_distance_comparisons += 1;
stats.distance_comparisons += 1;
return (distance, node.heap_item_pointer.deserialize_item_pointer());
return distance;
}

//now we know we're doing a distance calc on the full-sized vector
if self.quantizer.is_some() {
//have to get it from the heap
let heap_pointer = self.get_heap_pointer(index, index_pointer);
stats.node_reads += 1;
let heap_pointer = node.heap_item_pointer.deserialize_item_pointer();
let slot = TableSlot::new(self.heap_rel.unwrap());
self.init_slot(&slot, heap_pointer);
let slice = self.get_slice(&slot);
let distance = dm.get_full_vector_distance(slice, query);
stats.distance_comparisons += 1;
return (dm.get_full_vector_distance(slice, query), heap_pointer);
return distance;
} else {
//have to get it from the index
let rn = unsafe { Node::read(index, index_pointer) };
stats.node_reads += 1;
let node = rn.get_archived_node();
assert!(node.vector.len() > 0);
let vec = node.vector.as_slice();
let distance = dm.get_full_vector_distance(vec, query);
stats.distance_comparisons += 1;
return (distance, node.heap_item_pointer.deserialize_item_pointer());
return distance;
}
}

Expand Down Expand Up @@ -262,6 +255,7 @@ impl DistanceMeasure {
struct ListSearchNeighbor {
index_pointer: IndexPointer,
heap_pointer: HeapPointer,
neighbor_index_pointers: Vec<IndexPointer>,
distance: f32,
visited: bool,
}
Expand All @@ -279,10 +273,16 @@ impl PartialEq for ListSearchNeighbor {
}

impl ListSearchNeighbor {
pub fn new(index_pointer: IndexPointer, heap_pointer: HeapPointer, distance: f32) -> Self {
pub fn new(
index_pointer: IndexPointer,
heap_pointer: HeapPointer,
distance: f32,
neighbor_index_pointers: Vec<IndexPointer>,
) -> Self {
Self {
index_pointer,
heap_pointer,
neighbor_index_pointers,
distance,
visited: false,
}
Expand Down Expand Up @@ -349,12 +349,21 @@ impl ListSearchResult {
return;
}

let rn = unsafe { Node::read(index, index_pointer) };
self.stats.node_reads += 1;
let node = rn.get_archived_node();

let vp = graph.get_vector_provider();
let (dist, heap_pointer) =
unsafe { vp.get_distance(index, index_pointer, query, &self.dm, &mut self.stats) };
let distance = unsafe { vp.get_distance(node, query, &self.dm, &mut self.stats) };

let neighbor = ListSearchNeighbor::new(index_pointer, heap_pointer, dist);
self._insert_neighbor(neighbor);
let neighbors = graph.get_neighbors(node, index_pointer);
let lsn = ListSearchNeighbor::new(
index_pointer,
node.heap_item_pointer.deserialize_item_pointer(),
distance,
neighbors,
);
self._insert_neighbor(lsn);
}

/// Internal function
Expand All @@ -374,7 +383,7 @@ impl ListSearchResult {
self.best_candidate.insert(idx, n)
}

fn visit_closest(&mut self, pos_limit: usize) -> Option<(ItemPointer, f32)> {
fn visit_closest(&mut self, pos_limit: usize) -> Option<&ListSearchNeighbor> {
//OPT: should we optimize this not to do a linear search each time?
let neighbor_position = self.best_candidate.iter().position(|n| !n.visited);
match neighbor_position {
Expand All @@ -384,7 +393,7 @@ impl ListSearchResult {
}
let n = &mut self.best_candidate[pos];
n.visited = true;
Some((n.index_pointer, n.distance))
Some(n)
}
None => None,
}
Expand All @@ -404,12 +413,7 @@ impl ListSearchResult {
pub trait Graph {
fn read<'a>(&self, index: &'a PgRelation, index_pointer: ItemPointer) -> ReadableNode<'a>;
fn get_init_ids(&mut self) -> Option<Vec<ItemPointer>>;
fn get_neighbors(
&self,
index: &PgRelation,
neighbors_of: ItemPointer,
result: &mut Vec<IndexPointer>,
) -> bool;
fn get_neighbors(&self, node: &ArchivedNode, neighbors_of: ItemPointer) -> Vec<IndexPointer>;
fn get_neighbors_with_distances(
&self,
index: &PgRelation,
Expand Down Expand Up @@ -499,20 +503,16 @@ pub trait Graph {
Self: Graph,
{
//OPT: Only build v when needed.
let mut v: HashSet<_> = HashSet::<NeighborWithDistance>::with_capacity(visit_n_closest);
let mut neighbors =
Vec::<IndexPointer>::with_capacity(self.get_meta_page(index).get_num_neighbors() as _);
while let Some((index_pointer, distance)) = lsr.visit_closest(visit_n_closest) {
let mut v: HashSet<_> = HashSet::<NeighborWithDistance>::with_capacity(visit_n_closest);
while let Some(node) = lsr.visit_closest(visit_n_closest) {
neighbors.clear();
let neighbors_existed = self.get_neighbors(index, index_pointer, &mut neighbors);
if !neighbors_existed {
panic!("Nodes in the list search results that aren't in the builder");
}

for neighbor_index_pointer in &neighbors {
v.insert(NeighborWithDistance::new(node.index_pointer, node.distance));
neighbors.extend_from_slice(node.neighbor_index_pointers.as_slice());
for neighbor_index_pointer in neighbors.iter() {
lsr.insert(index, self, *neighbor_index_pointer, query)
}
v.insert(NeighborWithDistance::new(index_pointer, distance));
}

Some(v)
Expand Down

0 comments on commit 041db69

Please sign in to comment.