Skip to content

Commit

Permalink
Raise diskann maximum dimension from 2K to 16K (#181)
Browse files Browse the repository at this point in the history
This PR fixes #100 and raises the dimension limit for pgvectorscale's
diskann index from 2000 to 16000, which is the maximum supported by the
underlying pgvector `vector` type.

The previous limit of 2000 was needed to ensure that all data structures
could be serialized onto single 8K pages. When going beyond 2000
dimensions, so long as SBQ is used for storage, quantized vectors,
neighbor lists, and other data structures will still fit on a single
page; the only thing that grows too large is `SbqMeans`. (The raw
vectors used for reranking remain in the source relation, where standard
Postgres TOAST machinery is used to read/write them). If plain storage
is used, the old limit of 2000 remains in place.

To deal with `SbqMeans`, we introduce a `ChainTape` data structure that
is similar to `Tape` but supports reads/writes of large buffers across
pages. The chained representation is considered a property of the
`PageType`, and we introduce a new `PageType` for `SbqMeans` along with
upgrade machinery from the old version. As with the versioned
`MetaPage`, there are no unit tests for this, but I verified through
ad-hoc testing that the upgrade path works.
  • Loading branch information
tjgreen42 authored Dec 18, 2024
1 parent 8836117 commit 8bd2acf
Show file tree
Hide file tree
Showing 6 changed files with 476 additions and 56 deletions.
83 changes: 75 additions & 8 deletions pgvectorscale/src/access_method/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ impl<'a, 'b> BuildState<'a, 'b> {
}
}

/// Maximum number of dimensions supported by pgvector's vector type. Also
/// the maximum number of dimensions that can be indexed with diskann.
/// Indexing more than `MAX_DIMENSION_NO_SBQ` dimensions requires the SBQ
/// (memory_optimized) storage layout.
pub const MAX_DIMENSION: u32 = 16000;

/// Maximum number of dimensions that can be indexed with diskann without
/// using the SBQ storage type.
/// With plain storage every serialized data structure must fit on a single
/// 8K page, which caps the dimension count at 2000.
pub const MAX_DIMENSION_NO_SBQ: u32 = 2000;

#[pg_guard]
pub extern "C" fn ambuild(
heaprel: pg_sys::Relation,
Expand All @@ -73,7 +81,7 @@ pub extern "C" fn ambuild(
let opt = TSVIndexOptions::from_relation(&index_relation);

notice!(
"Starting index build. num_neighbors={} search_list_size={}, max_alpha={}, storage_layout={:?}",
"Starting index build with num_neighbors={}, search_list_size={}, max_alpha={}, storage_layout={:?}.",
opt.get_num_neighbors(),
opt.search_list_size,
opt.max_alpha,
Expand All @@ -98,10 +106,22 @@ pub extern "C" fn ambuild(
let meta_page =
unsafe { MetaPage::create(&index_relation, dimensions as _, distance_type, opt) };

assert!(
meta_page.get_num_dimensions_to_index() > 0
&& meta_page.get_num_dimensions_to_index() <= 2000
);
if meta_page.get_num_dimensions_to_index() == 0 {
error!("No dimensions to index");
}

if meta_page.get_num_dimensions_to_index() > MAX_DIMENSION {
error!("Too many dimensions to index (max is {})", MAX_DIMENSION);
}

if meta_page.get_num_dimensions_to_index() > MAX_DIMENSION_NO_SBQ
&& meta_page.get_storage_type() == StorageType::Plain
{
error!(
"Too many dimensions to index with plain storage (max is {}). Use storage_layout=memory_optimized instead.",
MAX_DIMENSION_NO_SBQ
);
}

let ntuples = do_heap_scan(index_info, &heap_relation, &index_relation, meta_page);

Expand Down Expand Up @@ -878,7 +898,7 @@ pub mod tests {
);
select setseed(0.5);
-- generate 300 vectors
-- generate {expected_cnt} vectors
INSERT INTO {table_name} (id, embedding)
SELECT
*
Expand Down Expand Up @@ -1036,7 +1056,7 @@ pub mod tests {
);
select setseed(0.5);
-- generate 300 vectors
-- generate {expected_cnt} vectors
INSERT INTO test_data (id, embedding)
SELECT
*
Expand Down Expand Up @@ -1086,7 +1106,7 @@ pub mod tests {
CREATE INDEX idx_diskann_bq ON test_data USING diskann (embedding) WITH ({index_options});
select setseed(0.5);
-- generate 300 vectors
-- generate {expected_cnt} vectors
INSERT INTO test_data (id, embedding)
SELECT
*
Expand Down Expand Up @@ -1114,4 +1134,51 @@ pub mod tests {
verify_index_accuracy(expected_cnt, dimensions)?;
Ok(())
}

/// End-to-end test that a diskann index can be built and queried at
/// dimension counts above the old 2000 limit, up to pgvector's 16000
/// maximum. For each dimension count, builds a fresh table and index,
/// inserts `expected_cnt` random vectors, forces an index scan ordered
/// by cosine distance, and then checks accuracy before dropping the table.
#[pg_test]
pub unsafe fn test_high_dimension_index() -> spi::Result<()> {
    let index_options = "num_neighbors=10, search_list_size=10";
    let expected_cnt = 1000;

    // Cover a spread of high dimension counts up to the 16000 maximum.
    for dimensions in [4000, 8000, 12000, 16000] {
        Spi::run(&format!(
            "CREATE TABLE test_data (
id int,
embedding vector ({dimensions})
);
CREATE INDEX idx_diskann_bq ON test_data USING diskann (embedding) WITH ({index_options});
select setseed(0.5);
-- generate {expected_cnt} vectors
INSERT INTO test_data (id, embedding)
SELECT
*
FROM (
SELECT
i % {expected_cnt},
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
FROM
generate_series(1, {dimensions} * {expected_cnt}) i
GROUP BY
i % {expected_cnt}) g;
SET enable_seqscan = 0;
-- perform index scans on the vectors
SELECT
*
FROM
test_data
ORDER BY
embedding <=> (
SELECT
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
FROM generate_series(1, {dimensions}));"))?;

        // NOTE(review): presumably compares index-scan results against the
        // expected row set — defined elsewhere in this module; confirm there.
        verify_index_accuracy(expected_cnt, dimensions)?;

        // Drop the table so the next dimension count starts from scratch.
        Spi::run("DROP TABLE test_data CASCADE;")?;
    }
    Ok(())
}
}
100 changes: 79 additions & 21 deletions pgvectorscale/src/access_method/sbq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ use pgrx::{
use rkyv::{vec::ArchivedVec, Archive, Deserialize, Serialize};

use crate::util::{
page::PageType, table_slot::TableSlot, tape::Tape, ArchivedItemPointer, HeapPointer,
IndexPointer, ItemPointer, ReadableBuffer,
chain::{ChainItemReader, ChainTapeWriter},
page::{PageType, ReadablePage},
table_slot::TableSlot,
tape::Tape,
ArchivedItemPointer, HeapPointer, IndexPointer, ItemPointer, ReadableBuffer,
};

use super::{meta_page::MetaPage, neighbor_with_distance::NeighborWithDistance};
Expand All @@ -33,33 +36,28 @@ const BITS_STORE_TYPE_SIZE: usize = 64;
#[derive(Archive, Deserialize, Serialize, Readable, Writeable)]
#[archive(check_bytes)]
#[repr(C)]
pub struct SbqMeans {
pub struct SbqMeansV1 {
count: u64,
means: Vec<f32>,
m2: Vec<f32>,
}

impl SbqMeans {
impl SbqMeansV1 {
pub unsafe fn load<S: StatsNodeRead>(
index: &PgRelation,
meta_page: &super::meta_page::MetaPage,
mut quantizer: SbqQuantizer,
qip: ItemPointer,
stats: &mut S,
) -> SbqQuantizer {
let mut quantizer = SbqQuantizer::new(meta_page);
if quantizer.use_mean {
if meta_page.get_quantizer_metadata_pointer().is_none() {
pgrx::error!("No SBQ pointer found in meta page");
}
let quantizer_item_pointer = meta_page.get_quantizer_metadata_pointer().unwrap();
let bq = SbqMeans::read(index, quantizer_item_pointer, stats);
let archived = bq.get_archived_node();

quantizer.load(
archived.count,
archived.means.to_vec(),
archived.m2.to_vec(),
);
}
assert!(quantizer.use_mean);
let bq = SbqMeansV1::read(index, qip, stats);
let archived = bq.get_archived_node();

quantizer.load(
archived.count,
archived.means.to_vec(),
archived.m2.to_vec(),
);
quantizer
}

Expand All @@ -69,7 +67,7 @@ impl SbqMeans {
stats: &mut S,
) -> ItemPointer {
let mut tape = Tape::new(index, PageType::SbqMeans);
let node = SbqMeans {
let node = SbqMeansV1 {
count: quantizer.count,
means: quantizer.mean.to_vec(),
m2: quantizer.m2.to_vec(),
Expand All @@ -80,6 +78,66 @@ impl SbqMeans {
}
}

/// Serialized form of the SBQ quantizer statistics (current on-disk
/// version). Written to a chained tape (`PageType::SbqMeans`), so unlike
/// the V1 format it may span multiple pages at high dimension counts.
#[derive(Archive, Deserialize, Serialize)]
#[archive(check_bytes)]
#[repr(C)]
pub struct SbqMeans {
    // Number of vectors folded into the running statistics.
    count: u64,
    // Per-dimension running means (copied from `SbqQuantizer::mean`).
    means: Vec<f32>,
    // Per-dimension second-moment accumulator; presumably Welford-style
    // sum of squared deviations — confirm against `SbqQuantizer`.
    m2: Vec<f32>,
}

impl SbqMeans {
    /// Loads the SBQ quantizer statistics for `index` and returns a fully
    /// initialized `SbqQuantizer`.
    ///
    /// If the quantizer does not use means, no on-disk state exists and a
    /// fresh quantizer is returned. Otherwise, the page type at the stored
    /// metadata pointer selects the on-disk format: the legacy single-page
    /// `SbqMeansV1` layout, or the current chained-tape `SbqMeans` layout
    /// (reassembled across pages before deserialization). Any other page
    /// type aborts via `pgrx::error!`.
    ///
    /// # Safety
    /// Performs raw index-page reads; caller must ensure `index` and the
    /// meta page describe a valid diskann index.
    pub unsafe fn load<S: StatsNodeRead>(
        index: &PgRelation,
        meta_page: &super::meta_page::MetaPage,
        stats: &mut S,
    ) -> SbqQuantizer {
        let mut quantizer = SbqQuantizer::new(meta_page);
        if !quantizer.use_mean {
            // Nothing was persisted for mean-less quantizers.
            return quantizer;
        }
        let qip = meta_page
            .get_quantizer_metadata_pointer()
            .unwrap_or_else(|| pgrx::error!("No SBQ pointer found in meta page"));

        // The page type encodes the serialization format, which allows
        // transparent upgrade from the old single-page representation.
        let page = ReadablePage::read(index, qip.block_number);
        let page_type = page.get_type();
        match page_type {
            PageType::SbqMeansV1 => SbqMeansV1::load(index, quantizer, qip, stats),
            PageType::SbqMeans => {
                // Current format: the payload is chained across pages, so
                // concatenate every chunk before deserializing.
                let mut tape_reader = ChainItemReader::new(index, PageType::SbqMeans, stats);
                let mut buf: Vec<u8> = Vec::new();
                for item in tape_reader.read(qip) {
                    buf.extend_from_slice(item.get_data_slice());
                }

                // unwrap: buffer was produced by `store` below, so a decode
                // failure indicates on-disk corruption (panic -> PG error).
                let means = rkyv::from_bytes::<SbqMeans>(buf.as_slice()).unwrap();
                quantizer.load(means.count, means.means, means.m2);
                quantizer
            }
            _ => {
                pgrx::error!("Invalid page type {} for SbqMeans", page_type as u8);
            }
        }
    }

    /// Serializes `quantizer`'s statistics with rkyv and writes them to a
    /// chained tape (`PageType::SbqMeans`), returning the item pointer that
    /// should be recorded in the meta page for later `load` calls.
    ///
    /// # Safety
    /// Performs raw index-page writes; caller must ensure `index` is a valid
    /// diskann index relation open for writing.
    pub unsafe fn store<S: StatsNodeWrite>(
        index: &PgRelation,
        quantizer: &SbqQuantizer,
        stats: &mut S,
    ) -> ItemPointer {
        let bq = SbqMeans {
            count: quantizer.count,
            means: quantizer.mean.clone(),
            m2: quantizer.m2.clone(),
        };
        let mut tape = ChainTapeWriter::new(index, PageType::SbqMeans, stats);
        // 1024 is rkyv's scratch-space hint; the buffer grows as needed.
        let buf = rkyv::to_bytes::<_, 1024>(&bq).unwrap();
        tape.write(&buf)
    }
}

#[derive(Clone)]
pub struct SbqQuantizer {
pub use_mean: bool,
Expand Down
Loading

0 comments on commit 8bd2acf

Please sign in to comment.