Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
tjgreen42 committed Jan 21, 2025
1 parent f2b0020 commit 0c12de6
Show file tree
Hide file tree
Showing 11 changed files with 430 additions and 97 deletions.
4 changes: 2 additions & 2 deletions pgvectorscale/src/access_method/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ unsafe fn insert_storage<S: Storage>(
let mut tape = Tape::resume(index_relation, S::page_type());
let index_pointer = storage.create_node(
vec.vec().to_index_slice(),
vec.labels().cloned(),
vec.labels(),
heap_pointer,
meta_page,
&mut tape,
Expand Down Expand Up @@ -509,7 +509,7 @@ fn build_callback_internal<S: Storage>(

let index_pointer = storage.create_node(
vec.vec().to_index_slice(),
vec.labels().cloned(),
vec.labels(),
heap_pointer,
&state.meta_page,
&mut state.tape,
Expand Down
85 changes: 61 additions & 24 deletions pgvectorscale/src/access_method/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@ use std::{cmp::Ordering, collections::HashSet};
use pgrx::PgRelation;

use crate::access_method::storage::NodeDistanceMeasure;

use crate::util::{HeapPointer, IndexPointer, ItemPointer};

use super::graph_neighbor_store::GraphNeighborStore;

use super::labels::LabeledVector;
use super::neighbor_with_distance::{Distance, DistanceWithTieBreak};
use super::pg_vector::PgVector;
use super::labels::{label_vec_to_set, Label, LabelSet, LabeledVector};
use super::meta_page::MetaPage;
use super::neighbor_with_distance::{Distance, DistanceWithTieBreak, NeighborWithDistance};
use super::start_nodes::StartNodes;
use super::stats::{GreedySearchStats, InsertStats, PruneNeighborStats, StatsNodeVisit};
use super::storage::Storage;
use super::{meta_page::MetaPage, neighbor_with_distance::NeighborWithDistance};

pub struct ListSearchNeighbor<PD> {
pub index_pointer: IndexPointer,
distance_with_tie_break: DistanceWithTieBreak,
private_data: PD,
labels: LabelSet,
}

impl<PD> PartialOrd for ListSearchNeighbor<PD> {
Expand Down Expand Up @@ -49,17 +48,23 @@ impl<PD> ListSearchNeighbor<PD> {
index_pointer: IndexPointer,
distance_with_tie_break: DistanceWithTieBreak,
private_data: PD,
labels: &[Label],
) -> Self {
Self {
index_pointer,
private_data,
distance_with_tie_break,
labels: labels.try_into().unwrap(),
}
}

pub fn get_private_data(&self) -> &PD {
&self.private_data
}

pub fn get_labels(&self) -> &LabelSet {
&self.labels
}
}

pub struct ListSearchResult<QDM, PD> {
Expand Down Expand Up @@ -105,7 +110,7 @@ impl<QDM, PD> ListSearchResult<QDM, PD> {
};
res.stats.record_call();
for index_pointer in init_ids {
let lsn = storage.create_lsn_for_init_id(&mut res, index_pointer, gns);
let lsn = storage.create_lsn_for_start_node(&mut res, index_pointer, gns);
res.insert_neighbor(lsn);
}
res
Expand Down Expand Up @@ -152,8 +157,8 @@ impl<QDM, PD> ListSearchResult<QDM, PD> {
}

let head = self.candidates.pop().unwrap();
let idx = self.visited.partition_point(|x| *x < head.0);
self.visited.insert(idx, head.0);
let idx = self.visited.partition_point(|x| *x < head.0); // TODO: O(n)
self.visited.insert(idx, head.0); // TODO: O(n)
Some(idx)
}

Expand Down Expand Up @@ -189,8 +194,8 @@ impl<'a> Graph<'a> {
&self.neighbor_store
}

fn get_init_ids(&self) -> Option<Vec<ItemPointer>> {
self.meta_page.get_init_ids()
pub fn get_start_nodes(&self) -> Option<&StartNodes> {
self.meta_page.get_start_nodes()
}

fn add_neighbors<S: Storage>(
Expand Down Expand Up @@ -274,16 +279,17 @@ impl<'a> Graph<'a> {
storage: &S,
stats: &mut GreedySearchStats,
) -> HashSet<NeighborWithDistance> {
let init_ids = self.get_init_ids();
let init_ids = self.get_start_nodes();
if init_ids.is_none() {
//no nodes in the graph
return HashSet::with_capacity(0);
}
let start_nodes = init_ids.unwrap().get_for_node(query.labels());
let dm = storage.get_query_distance_measure(query);
let search_list_size = meta_page.get_search_list_size_for_build() as usize;

let mut l = ListSearchResult::new(
init_ids.unwrap(),
start_nodes,
dm,
Some(index_pointer),
search_list_size,
Expand All @@ -305,15 +311,16 @@ impl<'a> Graph<'a> {
search_list_size: usize,
storage: &S,
) -> ListSearchResult<S::QueryDistanceMeasure, S::LSNPrivateData> {
let init_ids = self.get_init_ids();
if init_ids.is_none() {
//no nodes in the graph
let start_nodes = self.get_start_nodes();
if start_nodes.is_none() {
// No nodes in the graph
return ListSearchResult::empty();
}
let start_nodes = start_nodes.unwrap().get_for_node(query.labels());
let dm = storage.get_query_distance_measure(query);

ListSearchResult::new(
init_ids.unwrap(),
start_nodes,
dm,
None,
search_list_size,
Expand All @@ -339,6 +346,7 @@ impl<'a> Graph<'a> {
visited_nodes.insert(NeighborWithDistance::new(
list_search_entry.index_pointer,
list_search_entry.distance_with_tie_break.clone(),
list_search_entry.get_labels(),
));
}
}
Expand Down Expand Up @@ -438,18 +446,24 @@ impl<'a> Graph<'a> {
results
}

pub fn insert<S: Storage>(
fn update_start_nodes<S: Storage>(
&mut self,
index: &PgRelation,
index_pointer: IndexPointer,
vec: LabeledVector,
vec: &LabeledVector,
storage: &S,
stats: &mut InsertStats,
) {
if self.meta_page.get_init_ids().is_none() {
let start_nodes = self.meta_page.get_start_nodes();
if let Some(start_nodes) = start_nodes {
if start_nodes.contains(vec.labels()) {
return;
}
}

let mut start_nodes = if start_nodes.is_none() {
//TODO probably better set off of centeroids
MetaPage::update_init_ids(index, vec![index_pointer], stats);
*self.meta_page = MetaPage::fetch(index);
let start_nodes = StartNodes::new(index_pointer);

self.neighbor_store.set_neighbors(
storage,
Expand All @@ -460,12 +474,32 @@ impl<'a> Graph<'a> {
),
stats,
);
}

start_nodes
} else {
start_nodes.unwrap().clone()
};

start_nodes.add_node(vec.labels(), index_pointer);

MetaPage::set_start_nodes(index, start_nodes, stats);
}

pub fn insert<S: Storage>(
&mut self,
index: &PgRelation,
index_pointer: IndexPointer,
vec: LabeledVector,
storage: &S,
stats: &mut InsertStats,
) {
self.update_start_nodes(index, index_pointer, &vec, storage, stats);

let meta_page = self.get_meta_page();

//TODO: make configurable?
#[allow(clippy::mutable_key_type)]
let labels = label_vec_to_set(vec.labels());

let v = self.greedy_search_for_build(
index_pointer,
vec,
Expand All @@ -488,6 +522,7 @@ impl<'a> Graph<'a> {
let neighbor_contains_new_point = self.update_back_pointer(
neighbor.get_index_pointer_to_neighbor(),
index_pointer,
&labels,
neighbor.get_distance_with_tie_break(),
storage,
&mut stats.prune_neighbor_stats,
Expand All @@ -512,13 +547,15 @@ impl<'a> Graph<'a> {
&mut self,
from: IndexPointer,
to: IndexPointer,
to_labels: &LabelSet,
distance_with_tie_break: &DistanceWithTieBreak,
storage: &S,
prune_stats: &mut PruneNeighborStats,
) -> bool {
let new = vec![NeighborWithDistance::new(
to,
distance_with_tie_break.clone(),
to_labels,
)];
let (_pruned, n) = self.add_neighbors(storage, from, new.clone(), prune_stats);
n.contains(&new[0])
Expand Down
124 changes: 116 additions & 8 deletions pgvectorscale/src/access_method/labels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,57 @@ use pgrx::{
Array, FromDatum,
};

pub type Labels = Vec<u16>;
pub type Label = u8;
pub type LabelSet = [Label; MAX_LABELS_PER_NODE];
pub type LabelVec = Vec<Label>;

pub const INVALID_LABEL: u8 = 255;
pub const MAX_LABELS_PER_NODE: usize = 8;
pub const MAX_LABEL: u8 = 254;

/// Returns true if the two label sets overlap. Assumes labels are sorted.
pub fn test_overlap(labels1: &[Label], labels2: &[Label]) -> bool {
debug_assert!(labels1.is_sorted());
debug_assert!(labels2.is_sorted());

let mut i = 0;
let mut j = 0;
while i < labels1.len() && j < labels2.len() {
if labels1[i] == labels2[j] {
return true;
} else if labels1[i] < labels2[j] {
i += 1;
} else {
j += 1;
}
}
false
}

pub fn label_vec_to_set(labels: Option<&[Label]>) -> LabelSet {
let mut set = [INVALID_LABEL; MAX_LABELS_PER_NODE];
if let Some(labels) = labels {
debug_assert!(labels.len() <= MAX_LABELS_PER_NODE);
debug_assert!(labels.is_sorted());

for (i, &label) in labels.iter().enumerate() {
set[i] = label;
}
}
set
}

pub fn new_labelset() -> LabelSet {
[INVALID_LABEL; MAX_LABELS_PER_NODE]
}

pub struct LabeledVector {
vec: PgVector,
labels: Option<Labels>,
labels: Option<LabelVec>,
}

impl LabeledVector {
pub fn new(vec: PgVector, labels: Option<Labels>) -> Self {
pub fn new(vec: PgVector, labels: Option<LabelVec>) -> Self {
Self { vec, labels }
}

Expand All @@ -25,7 +67,7 @@ impl LabeledVector {

let labels = if meta_page.has_labels() {
let arr = Array::<i32>::from_datum(*values.add(1), *isnull.add(1));
arr.map(|arr| arr.into_iter().flatten().map(|x| x as u16).collect())
arr.map(|arr| arr.into_iter().flatten().map(|x| x as Label).collect())
} else {
None
};
Expand All @@ -47,9 +89,9 @@ impl LabeledVector {
)
};

let labels: Option<Vec<u16>> = (!keys.is_empty()).then(|| {
let labels: Option<Vec<Label>> = (!keys.is_empty()).then(|| {
let arr = unsafe { Array::<i32>::from_datum(keys[0].sk_argument, false).unwrap() };
arr.into_iter().flatten().map(|i| i as u16).collect()
arr.into_iter().flatten().map(|i| i as Label).collect()
});

Self::new(query, labels)
Expand All @@ -59,7 +101,73 @@ impl LabeledVector {
&self.vec
}

pub fn labels(&self) -> Option<&Vec<u16>> {
self.labels.as_ref()
pub fn labels(&self) -> Option<&[Label]> {
self.labels.as_deref()
}
}

/// Test cases for test_overlap
#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_test_overlap() {
assert_eq!(test_overlap(&[], &[]), false);
assert_eq!(test_overlap(&[1], &[]), false);
assert_eq!(test_overlap(&[], &[1]), false);
assert_eq!(test_overlap(&[1], &[1]), true);
assert_eq!(test_overlap(&[1], &[2]), false);
assert_eq!(test_overlap(&[1, 2], &[2]), true);
assert_eq!(test_overlap(&[1, 2], &[3]), false);
assert_eq!(test_overlap(&[1, 2], &[2, 3]), true);
assert_eq!(test_overlap(&[1, 2], &[3, 4]), false);
assert_eq!(test_overlap(&[1, 2], &[2, 3]), true);
assert_eq!(test_overlap(&[1, 2], &[2, 3, 4]), true);
assert_eq!(test_overlap(&[1, 2], &[3, 4, 5]), false);
}

/// Test label_vec_to_set
#[test]
fn test_label_vec_to_set() {
assert_eq!(label_vec_to_set(None), [INVALID_LABEL; MAX_LABELS_PER_NODE]);
assert_eq!(
label_vec_to_set(Some(&[])),
[INVALID_LABEL; MAX_LABELS_PER_NODE]
);
assert_eq!(
label_vec_to_set(Some(&[1])),
[
1,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL
]
);
assert_eq!(
label_vec_to_set(Some(&[1, 2])),
[
1,
2,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL,
INVALID_LABEL
]
);
assert_eq!(
label_vec_to_set(Some(&[1, 2, 3, 4, 5, 6, 7])),
[1, 2, 3, 4, 5, 6, 7, INVALID_LABEL]
);
assert_eq!(
label_vec_to_set(Some(&[1, 2, 3, 4, 5, 6, 7, 8])),
[1, 2, 3, 4, 5, 6, 7, 8]
);
}
}
Loading

0 comments on commit 0c12de6

Please sign in to comment.