From 9b1fe95d3faaac7b4e2151023858783b64d08f38 Mon Sep 17 00:00:00 2001
From: Siddhant Chaudhary
Date: Sun, 15 Sep 2024 18:06:18 +0530
Subject: [PATCH 1/3] Adding tests for search functions, and related minor changes.

---
 src/search/ranking.jl | 13 ++-
 src/searching.jl | 1 -
 test/indexing/codecs/residual.jl | 2 +-
 test/indexing/collection_indexer.jl | 3 +-
 test/runtests.jl | 4 +
 test/search/ranking.jl | 162 ++++++++++++++++++++++++++++
 test/searching.jl | 43 ++++++++
 7 files changed, 221 insertions(+), 7 deletions(-)
 create mode 100644 test/search/ranking.jl
 create mode 100644 test/searching.jl

diff --git a/src/search/ranking.jl b/src/search/ranking.jl
index ce1392c..a838e75 100644
--- a/src/search/ranking.jl
+++ b/src/search/ranking.jl
@@ -6,7 +6,10 @@ Get the set of embedding IDs contained in `centroid_ids`.
 """
 function _cids_to_eids!(eids::Vector{Int}, centroid_ids::Vector{Int},
         ivf::Vector{Int}, ivf_lengths::Vector{Int})
-    @assert length(eids) == sum(ivf_lengths[centroid_ids])
+    length(eids) == sum(ivf_lengths[centroid_ids]) ||
+        throw(DimensionMismatch("length(eids) must be equal to sum(ivf_lengths[centroid_ids])!"))
+    length(ivf) == sum(ivf_lengths) ||
+        throw(DimensionMismatch("length(ivf) must be equal to sum(ivf_lengths)!"))
     centroid_ivf_offsets = cumsum([1; _head(ivf_lengths)])
     eid_offsets = cumsum([1; _head(ivf_lengths[centroid_ids])])
     for (idx, centroid_id) in enumerate(centroid_ids)
@@ -65,17 +68,19 @@ end

 function maxsim(Q::AbstractMatrix{Float32}, D::AbstractMatrix{Float32},
         pids::Vector{Int}, doclens::Vector{Int})
+    sum(doclens[pids]) == size(D, 2) ||
+        throw(DimensionMismatch("The total number of embeddings " *
+                                "for pids does not match with the " *
+                                "dimension of D!"))
     scores = zeros(Float32, length(pids))
-    num_embeddings = sum(doclens[pids])
     query_doc_scores = Q' * D
     offsets = cumsum([1; _head(doclens[pids])])
     for (idx, pid) in enumerate(pids)
         num_embs_pids = doclens[pid]
         offset = offsets[idx]
-        offset_end = min(num_embeddings, offset + num_embs_pids - 1)
+        offset_end = offset + num_embs_pids - 1
         pid_scores = query_doc_scores[:, offset:offset_end]
         scores[idx] = sum(maximum(pid_scores, dims = 2))
-        offset += num_embs_pids
     end
     scores
 end
diff --git a/src/searching.jl b/src/searching.jl
index c4046d6..e617931 100644
--- a/src/searching.jl
+++ b/src/searching.jl
@@ -87,7 +87,6 @@ function _build_emb2pid(doclens::Vector{Int})
         offset = embs2pid_offsets[pid]
         emb2pid[offset:(offset + dlength - 1)] .= pid
     end
-    @assert all(!=(0), emb2pid)
     emb2pid
 end

diff --git a/test/indexing/codecs/residual.jl b/test/indexing/codecs/residual.jl
index 56180c7..ab1b12c 100644
--- a/test/indexing/codecs/residual.jl
+++ b/test/indexing/codecs/residual.jl
@@ -23,7 +23,7 @@ using ColBERT: _normalize_array!, compress_into_codes!, _binarize, _unbinarize,
     @test isequal(codes, sortperm(perm)) # sortperm(perm) -> inverse mapping

     # Test 3: sample centroids randomly from embeddings
-    embs = rand(Float32, rand(1:128), rand(1:20))
+    embs = rand(Float32, rand(2:128), rand(1:20))
     _normalize_array!(embs; dims = 1)
     perm = collect(take(randperm(size(embs, 2)), rand(1:size(embs, 2))))
     centroids = embs[:, perm]
diff --git a/test/indexing/collection_indexer.jl b/test/indexing/collection_indexer.jl
index e987e1c..50eddbe 100644
--- a/test/indexing/collection_indexer.jl
+++ b/test/indexing/collection_indexer.jl
@@ -293,12 +293,13 @@ end

     # Test 2: Testing types, shapes and range of vals
     num_partitions = rand(1:1000)
-    codes = UInt32.(rand(1:num_partitions, 10000)) # Large array with random values
+    codes = 
UInt32.(rand(1:num_partitions, 10000)) # Large array with random values ivf, ivf_lengths = _build_ivf(codes, num_partitions) @test length(ivf) == length(codes) @test sum(ivf_lengths) == length(codes) @test length(ivf_lengths) == num_partitions @test all(in(ivf), codes) + @test length(unique(ivf)) == length(ivf) # an eid can occur in atmost one cluster @test ivf isa Vector{Int} @test ivf_lengths isa Vector{Int} end diff --git a/test/runtests.jl b/test/runtests.jl index 1d649cd..c6d6df1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,10 @@ include("indexing/collection_indexer.jl") include("modelling/tokenization/tokenizer_utils.jl") include("modelling/embedding_utils.jl") +# search operations +include("searching.jl") +include("search/ranking.jl") + # utils include("utils.jl") diff --git a/test/search/ranking.jl b/test/search/ranking.jl new file mode 100644 index 0000000..5b1ee8a --- /dev/null +++ b/test/search/ranking.jl @@ -0,0 +1,162 @@ +using ColBERT: _cids_to_eids!, retrieve, _collect_compressed_embs_for_pids, + maxsim + +@testset "_cids_to_eids!" begin + # Test 1: Correct conversion + eids = Vector{Int}(undef, 5) + centroid_ids = [2, 1] + ivf = [1, 2, 3, 4, 5, 6] + ivf_lengths = [3, 2, 1] + _cids_to_eids!(eids, centroid_ids, ivf, ivf_lengths) + @test eids == [4, 5, 1, 2, 3] + + # Test 2: Random partitioning over a large vector + # centroid_ids don't have to be sorted + num_embeddings = rand(1:1000) + num_partitions = rand(1:20) + ivf = Int[] + ivf_lengths = zeros(Int, num_partitions) + assignments = rand(1:num_partitions, num_embeddings) # eid to cid + ivf_mapping = [Int[] for centroid_id in 1:num_partitions] + for (eid, assignment) in enumerate(assignments) + push!(ivf_mapping[assignment], eid) + end + for centroid_id in 1:num_partitions + shuffle!(ivf_mapping[centroid_id]) + append!(ivf, ivf_mapping[centroid_id]) + ivf_lengths[centroid_id] = length(ivf_mapping[centroid_id]) + end + centroid_ids = randperm(num_partitions)[1:rand(1:num_partitions)] + eids = Vector{Int}(undef, sum(ivf_lengths[centroid_ids])) + _cids_to_eids!(eids, centroid_ids, ivf, ivf_lengths) + for centroid_id in centroid_ids + @test collect(take(eids, ivf_lengths[centroid_id])) == + ivf_mapping[centroid_id] + eids = eids[(ivf_lengths[centroid_id] + 1):end] + end + + # Test 3: Empty eids + eids = Vector{Int}(undef, 0) + centroid_ids = Int[] + ivf = Int[] + ivf_lengths = Int[] + _cids_to_eids!(eids, centroid_ids, ivf, ivf_lengths) + @test eids == [] + + # Test 4: Empty centroid_ids + eids = Vector{Int}(undef, 0) + centroid_ids = Int[] + ivf = [1, 2, 3, 4, 5, 6] + ivf_lengths = [3, 2, 1] + _cids_to_eids!(eids, centroid_ids, ivf, ivf_lengths) + @test all(iszero, eids) + + # Test 5: eids DimensionMismatch + eids = Vector{Int}(undef, 5) + centroid_ids = [1, 2, 3] + ivf = [1, 2, 3, 4, 5, 6] + ivf_lengths = [2, 2, 2] + @test_throws DimensionMismatch _cids_to_eids!( + eids, centroid_ids, ivf, ivf_lengths) + + # Test 6: ivf and ivf_lengths DimensionMismatch + eids = Vector{Int}(undef, 6) + centroid_ids = [1, 2, 3] + ivf = [1, 2, 3, 4, 5] + ivf_lengths = [2, 2, 2] + @test_throws DimensionMismatch _cids_to_eids!( + eids, centroid_ids, ivf, ivf_lengths) +end + +@testset "retrieve" begin + # Test 1: A small and basic case + # The first and last centroids are closest to the query + ivf = [3, 1, 4, 5, 6, 2] + ivf_lengths = [2, 3, 1] + centroids = Float32[1.0 0.0 0.0; 0.0 0.0 1.0] + emb2pid = [10, 20, 30, 40, 50, 60] + nprobe = 2 + Q = Float32[0.5 0.5]' + expected_pids = [10, 20, 30] + pids = retrieve(ivf, 
ivf_lengths, centroids, emb2pid, nprobe, Q) + @test pids == expected_pids +end + +@testset "_collect_compressed_embs_for_pids" begin + # Test 1: Small example + doclens = [3, 2, 4] + codes = UInt32[1, 2, 3, 4, 5, 6, 7, 8, 9] + residuals = [0x11 0x12 0x13 0x14 0x15 0x16 0x17 0x18 0x19; + 0x21 0x22 0x23 0x24 0x25 0x26 0x27 0x28 0x29] + pids = [1, 3] + expected_codes_packed = UInt32[1, 2, 3, 6, 7, 8, 9] + expected_residuals_packed = UInt8[0x11 0x12 0x13 0x16 0x17 0x18 0x19; + 0x21 0x22 0x23 0x26 0x27 0x28 0x29] + codes_packed, residuals_packed = _collect_compressed_embs_for_pids( + doclens, codes, residuals, pids) + @test codes_packed == expected_codes_packed + @test residuals_packed == expected_residuals_packed + + # Test 2: Edge case - Empty pids + pids = Int[] + expected_codes_packed = UInt32[] + expected_residuals_packed = zeros(UInt8, 2, 0) + codes_packed, residuals_packed = _collect_compressed_embs_for_pids( + doclens, codes, residuals, pids) + @test codes_packed == expected_codes_packed + @test residuals_packed == expected_residuals_packed + + # Test 3: Edge case - doclens with zero values + doclens = [3, 0, 4] + codes = UInt32[1, 2, 3, 6, 7, 8, 9] + residuals = [0x11 0x12 0x13 0x16 0x17 0x18 0x19; + 0x21 0x22 0x23 0x26 0x27 0x28 0x29] + pids = [1, 3] + expected_codes_packed = UInt32[1, 2, 3, 6, 7, 8, 9] + expected_residuals_packed = UInt8[0x11 0x12 0x13 0x16 0x17 0x18 0x19; + 0x21 0x22 0x23 0x26 0x27 0x28 0x29] + codes_packed, residuals_packed = _collect_compressed_embs_for_pids( + doclens, codes, residuals, pids) + @test codes_packed == expected_codes_packed + @test residuals_packed == expected_residuals_packed + + # Test 4: Shapes and types + num_pids = rand(1:1000) + doclens = rand(1:100, num_pids) + codes = rand(UInt32, sum(doclens)) + residuals = rand(UInt8, 16, sum(doclens)) + pids = rand(1:num_pids, rand(1:num_pids)) + codes_packed, residuals_packed = _collect_compressed_embs_for_pids( + doclens, codes, residuals, pids) + @test length(codes_packed) == sum(doclens[pids]) + @test size(residuals_packed) == (16, sum(doclens[pids])) + @test codes_packed isa Vector{UInt32} + @test residuals_packed isa Matrix{UInt8} +end + +@testset "maxsim" begin + # Test 1: Basic test + Q = Float32[1.0 0.5; 0.5 1.0] # Query matrix (2 query vectors) + D = Float32[0.8 0.3 0.1; 0.2 0.7 0.4] # Document matrix (3 document vectors) + pids = [1, 2] # Document pids to match + doclens = [1, 2] # Number of document vectors per pid + expected_scores = Float32[1.5, 1.5] + scores = maxsim(Q, D, pids, doclens) + @test scores == expected_scores + + # Test 2: Dimension mismatch case + Q = Float32[1.0 0.5; 0.5 1.0] # Query matrix + D = Float32[0.8 0.3] # Document matrix with incorrect dimension + pids = [1, 2] # pids to match + doclens = [1, 2] # Document lengths + @test_throws DimensionMismatch maxsim(Q, D, pids, doclens) + + # Test 3: Shapes and types + doclens = rand(1:10, 1000) + Q = rand(Float32, 128, 100) # 100 query vectors, each of size 128 + D = rand(Float32, 128, sum(doclens)) # document vectors, each of size 128 + pids = collect(1:1000) + scores = maxsim(Q, D, pids, doclens) + @test length(scores) == length(pids) + @test scores isa Vector{Float32} +end diff --git a/test/searching.jl b/test/searching.jl new file mode 100644 index 0000000..20868ec --- /dev/null +++ b/test/searching.jl @@ -0,0 +1,43 @@ +using ColBERT: _build_emb2pid + +@testset "_build_emb2pid" begin + # Test 1: A single document + doclens = rand(1:1000, 1) + emb2pid = _build_emb2pid(doclens) + @test emb2pid == ones(Int, doclens[1]) + + # 
Test 2: Small test with a custom output + doclens = [3, 2, 4] + emb2pid = _build_emb2pid(doclens) + @test emb2pid == [1, 1, 1, 2, 2, 3, 3, 3, 3] + + # Test 3: With some zero document lengths + doclens = [0, 2, 0, 3] + emb2pid = _build_emb2pid(doclens) + @test emb2pid == [2, 2, 4, 4, 4] + + # Test 3: Large, random inputs with equal doclengths + doclen = rand(1:1000) + doclens = doclen * ones(Int, rand(1:500)) + emb2pid = _build_emb2pid(doclens) + @test emb2pid == repeat(1:length(doclens), inner = doclen) + + # Test 4: with no documents + doclens = Int[] + emb2pid = _build_emb2pid(doclens) + @test emb2pid == Int[] + + # Test 5: Range of values, shapes and type + doclens = rand(0:100, rand(1:500)) + non_zero_docs = findall(>(0), doclens) + zero_docs = findall(==(0), doclens) + emb2pid = _build_emb2pid(doclens) + @test all(in(non_zero_docs), emb2pid) + @test issorted(emb2pid) + for pid in non_zero_docs + @test count(==(pid), emb2pid) == doclens[pid] + end + @test length(emb2pid) == sum(doclens[non_zero_docs]) + @test emb2pid isa Vector{Int} +end + From 74dd70ad80fb808492af90223042397c8ac87312 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 17 Sep 2024 10:45:14 +0530 Subject: [PATCH 2/3] Removing jldoctest. --- src/infra/config.jl | 3 +-- src/loaders.jl | 2 +- src/savers.jl | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/infra/config.jl b/src/infra/config.jl index 8bffd92..528e685 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -41,7 +41,7 @@ A [`ColBERTConfig`](@ref) object. Most users will just want to use the defaults for most settings. Here's a minimal example: -```jldoctest +```julia-repl julia> using ColBERT; julia> config = ColBERTConfig( @@ -49,7 +49,6 @@ julia> config = ColBERTConfig( collection = "/home/codetalker7/documents", index_path = "./local_index" ); - ``` """ Base.@kwdef struct ColBERTConfig diff --git a/src/loaders.jl b/src/loaders.jl index 5af5d0a..d79659b 100644 --- a/src/loaders.jl +++ b/src/loaders.jl @@ -48,7 +48,7 @@ Load a [`ColBERTConfig`](@ref) from disk. # Examples -```jldoctest +```julia-repl julia> using ColBERT; julia> config = ColBERTConfig( diff --git a/src/savers.jl b/src/savers.jl index b3b1453..8a1fb03 100644 --- a/src/savers.jl +++ b/src/savers.jl @@ -93,7 +93,7 @@ Save a [`ColBERTConfig`](@ref) to disk in JSON. # Examples -```jldoctest +```julia-repl julia> using ColBERT; julia> config = ColBERTConfig( From 5484e6c6a080af7117ca206a6e8ca7a658f32f92 Mon Sep 17 00:00:00 2001 From: Siddhant Chaudhary Date: Tue, 17 Sep 2024 10:48:23 +0530 Subject: [PATCH 3/3] Updating the compat helper. 
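
Besides the CompatHelper workflow, this patch also updates several docstrings to
describe the decomposed model components (`bert`, `linear`, `tokenizer`) instead
of the earlier `Checkpoint`-based arguments, and replaces stale `@ref` links to
internal helpers with plain code formatting.

For reference, a minimal indexing sketch against the documented API follows; the
paths and settings are placeholders, not values introduced by this patch:

```julia
using ColBERT: ColBERTConfig, Indexer

# Hypothetical configuration; substitute a local ColBERT checkpoint and a
# directory of documents.
config = ColBERTConfig(
    use_gpu = true,
    checkpoint = "/path/to/colbertv2.0",
    collection = "/path/to/documents",
    index_path = "./local_index"
)

# Indexer(config) loads the tokenizer, BERT and linear layers from the
# checkpoint (see the src/indexing.jl hunk below).
indexer = Indexer(config)
```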
--- .github/workflows/CompatHelper.yml | 37 +++++++++++-- src/indexing.jl | 4 +- src/indexing/collection_indexer.jl | 79 +++++++++++++++++---------- src/infra/config.jl | 4 +- src/modelling/checkpoint.jl | 87 +++++++++++++----------------- src/savers.jl | 11 ++-- test/search/ranking.jl | 4 +- test/searching.jl | 13 +++-- 8 files changed, 137 insertions(+), 102 deletions(-) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index cba9134..0918161 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -3,14 +3,43 @@ on: schedule: - cron: 0 0 * * * workflow_dispatch: +permissions: + contents: write + pull-requests: write jobs: CompatHelper: runs-on: ubuntu-latest steps: - - name: Pkg.add("CompatHelper") - run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - - name: CompatHelper.main() + - name: Check if Julia is already available in the PATH + id: julia_in_path + run: which julia + continue-on-error: true + - name: Install Julia, but only if it is not already available in the PATH + uses: julia-actions/setup-julia@v1 + with: + version: '1' + arch: ${{ runner.arch }} + if: steps.julia_in_path.outcome != 'success' + - name: "Add the General registry via Git" + run: | + import Pkg + ENV["JULIA_PKG_SERVER"] = "" + Pkg.Registry.add("General") + shell: julia --color=yes {0} + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "3" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main() + shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main()' + # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }} diff --git a/src/indexing.jl b/src/indexing.jl index b80091e..a75fd5e 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -18,8 +18,8 @@ Type representing an ColBERT indexer. # Returns -An [`Indexer`] wrapping a [`ColBERTConfig`](@ref), a [`Checkpoint`](@ref) and -a collection of documents to index. +An [`Indexer`] wrapping a [`ColBERTConfig`](@ref) along with the trained ColBERT +model. """ function Indexer(config::ColBERTConfig) tokenizer, bert, linear = load_hgf_pretrained_local(config.checkpoint) diff --git a/src/indexing/collection_indexer.jl b/src/indexing/collection_indexer.jl index f885019..dc19e72 100644 --- a/src/indexing/collection_indexer.jl +++ b/src/indexing/collection_indexer.jl @@ -24,8 +24,10 @@ function _sample_pids(num_documents::Int) end """ - _sample_embeddings(config::ColBERTConfig, checkpoint::Checkpoint, - collection::Vector{String}, sampled_pids::Set{Int}) + _sample_embeddings(bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + dim::Int, index_bsize::Int, doc_token::String, + skiplist::Vector{Int}, collection::Vector{String}) Compute embeddings for the PIDs sampled by [`_sample_pids`](@ref). @@ -35,14 +37,18 @@ total number of embeddings over all documents. # Arguments - - `config`: The [`ColBERTConfig`](@ref) to be used. - - `checkpoint`: The [`Checkpoint`] used to encode the passages. + - `bert`: The pre-trained BERT component of ColBERT. + - `linear`: The pre-trained linear component of ColBERT. + - `tokenizer`: The tokenizer to be used. + - `dim`: The embedding dimension. + - `index_bsize`: The batch size to be used to run the transformer. 
See [`ColBERTConfig`](@ref). + - `doc_token`: The document token. See [`ColBERTConfig`](@ref). + - `skiplist`: List of tokens to skip. - `collection`: The underlying collection of passages to get the samples from. - - `sampled_pids`: Set of PIDs sampled by [`_sample_pids`](@ref). # Returns -A `Dict` containing the average document length (i.e number of attended tokens) computed +A tuple containing the average document length (i.e number of attended tokens) computed from the sampled documents, and the embedding matrix for the local samples. The matrix has shape `(D, N)`, where `D` is the embedding dimension (`128`) and `N` is the total number of embeddings over all the sampled passages. @@ -85,24 +91,22 @@ function _heldout_split( end """ - setup(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) + setup(collection::Vector{String}, avg_doclen_est::Float32, + num_clustering_embs::Int, chunksize::Union{Missing, Int}, nranks::Int) -Initialize the index by computing some indexing-specific estimates and save the indexing plan -to disk. +Initialize the index by computing some indexing-specific estimates and the index plan. The number of chunks into which the document embeddings will be stored is simply computed using -the number of documents and the size of a chunk. A bunch of pids used for initializing the -centroids for the embedding clusters are sampled using the [`_sample_pids`](@ref) -and [`_sample_embeddings`](@ref) functions, and these samples are used to calculate the -average document lengths and the estimated number of embeddings which will be computed across -all documents. Finally, the number of clusters to be used for indexing is computed, and is -proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``. +the number of documents and the size of a chunk. The number of clusters to be used for indexing +is computed, and is proportional to ``16\\sqrt{\\text{Estimated number of embeddings}}``. # Arguments - - `config`: The [`ColBERTConfig`](@ref) being used to set up the indexing. - - `checkpoint`: The [`Checkpoint`](@ref) used to compute embeddings. - - `collection`: The underlying collection of passages to initialize the index for. + - `collection`: The collection of documents to index. + - `avg_doclen_est`: The collection of documents to index. + - `num_clustering_embs`: The number of embeddings to be used for computing the clusters. + - `chunksize`: The size of a chunk to be used. Can be `Missing`. + - `nranks`: Number of GPUs. Currently this can only be `1`. # Returns @@ -148,9 +152,9 @@ function _bucket_cutoffs_and_weights( end """ - _compute_avg_residuals( + _compute_avg_residuals!( nbits::Int, centroids::AbstractMatrix{Float32}, - heldout::AbstractMatrix{Float32}) + heldout::AbstractMatrix{Float32}, codes::AbstractVector{UInt32}) Compute the average residuals and other statistics of the held-out sample embeddings. @@ -162,7 +166,8 @@ Compute the average residuals and other statistics of the held-out sample embedd where `D` is the embedding dimension (`128`) and `indexer.num_partitions` is the number of clusters. - `heldout`: A matrix containing the held-out embeddings, computed using - [`_concatenate_and_split_sample`](@ref). + `_heldout_split`. + - `codes`: The array used to store the codes for each heldout embedding. # Returns @@ -196,7 +201,7 @@ end Compute centroids using a ``k``-means clustering algorithn, and store the compression information on disk. 
-Average residuals and other compression data is computed via the [`_compute_avg_residuals`](@ref) +Average residuals and other compression data is computed via the `_compute_avg_residuals`. function. # Arguments @@ -232,20 +237,36 @@ function train( end """ - index(config::ColBERTConfig, checkpoint::Checkpoint, collection::Vector{String}) + index(index_path::String, bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + collection::Vector{String}, dim::Int, index_bsize::Int, + doc_token::String, skiplist::Vector{Int}, num_chunks::Int, + chunksize::Int, centroids::AbstractMatrix{Float32}, + bucket_cutoffs::AbstractVector{Float32}, nbits::Int) -Build the index using `indexer`. +Build the index using for the `collection`. -The documents are processed in batches of size `chunksize`, determined by the config -(see [`ColBERTConfig`](@ref) and [`setup`](@ref)). Embeddings and document lengths are -computed for each batch (see [`encode_passages`](@ref)), and they are saved to disk +The documents are processed in batches of size `chunksize` (see [`setup`](@ref)). +Embeddings and document lengths are computed for each batch +(see [`encode_passages`](@ref)), and they are saved to disk along with relevant metadata (see [`save_chunk`](@ref)). # Arguments - - `config`: The [`ColBERTConfig`](@ref) being used. - - `checkpoint`: The [`Checkpoint`](@ref) to compute embeddings. + - `index_path`: Path where the index is to be saved. + - `bert`: The pre-trained BERT component of the ColBERT model. + - `linear`: The pre-trained linear component of the ColBERT model. + - `tokenizer`: Tokenizer to be used. - `collection`: The collection to index. + - `dim`: The embedding dimension. + - `index_bsize`: The batch size used for running the transformer. + - `doc_token`: The document token. + - `skiplist`: List of tokens to skip. + - `num_chunks`: Total number of chunks. + - `chunksize`: The maximum size of a chunk. + - `centroids`: Centroids used to compute the compressed representations. + - `bucket_cutoffs`: Cutoffs used to compute the residuals. + - `nbits`: Number of bits to encode the residuals in. """ function index(index_path::String, bert::HF.HGFBertModel, linear::Layers.Dense, tokenizer::TextEncoders.AbstractTransformerTextEncoder, diff --git a/src/infra/config.jl b/src/infra/config.jl index 528e685..5ff3aa8 100644 --- a/src/infra/config.jl +++ b/src/infra/config.jl @@ -30,8 +30,8 @@ Structure containing config for running and training various components. - `passages_batch_size`: The number of passages sent as a batch to encoding functions. Default is `300`. - `nbits`: Number of bits used to compress residuals. - `kmeans_niters`: Number of iterations used for k-means clustering. - - `nprobe`: The number of nearest centroids to fetch during a search. Default is `2`. Also see [`retrieve`](@ref). - - `ncandidates`: The number of candidates to get during candidate generation in search. Default is `8192`. Also see [`retrieve`](@ref). + - `nprobe`: The number of nearest centroids to fetch during a search. Default is `2`. Also see `retrieve`. + - `ncandidates`: The number of candidates to get during candidate generation in search. Default is `8192`. Also see `retrieve`. 
# Returns diff --git a/src/modelling/checkpoint.jl b/src/modelling/checkpoint.jl index da61a10..3e16cec 100644 --- a/src/modelling/checkpoint.jl +++ b/src/modelling/checkpoint.jl @@ -1,52 +1,22 @@ """ - doc( - config::ColBERTConfig, checkpoint::Checkpoint, integer_ids::AbstractMatrix{Int32}, - integer_mask::AbstractMatrix{Bool}) + doc(bert::HF.HGFBertModel, linear::Layers.Dense, + integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}) Compute the hidden state of the BERT and linear layers of ColBERT for documents. # Arguments - - `config`: The [`ColBERTConfig`](@ref) being used. - - `checkpoint`: The [`Checkpoint`](@ref) containing the layers to compute the embeddings. + - `bert`: The pre-trained BERT component of the ColBERT model. + - `linear`: The pre-trained linear component of the ColBERT model. - `integer_ids`: An array of token IDs to be fed into the BERT model. - `integer_mask`: An array of corresponding attention masks. Should have the same shape as `integer_ids`. # Returns -A tuple `D, mask`, where: - - - `D` is an array containing the normalized embeddings for each token in each document. - It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer - of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of - any document and `N` is the total number of documents. - - `mask` is an array containing attention masks for all documents, after masking out any - tokens in the `skiplist` of `checkpoint`. It has shape `(1, L, N)`, where `(L, N)` - is the same as described above. - -# Examples - -Continuing from the example in [`tensorize_docs`](@ref) and [`Checkpoint`](@ref): - -```julia-repl -julia> integer_ids, integer_mask = batches[1] - -julia> D, mask = ColBERT.doc(config, checkpoint, integer_ids, integer_mask); - -julia> typeof(D), size(D) -(CuArray{Float32, 3, CUDA.DeviceMemory}, (128, 21, 3)) - -julia> mask -1×21×3 CuArray{Bool, 3, CUDA.DeviceMemory}: -[:, :, 1] = - 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - -[:, :, 2] = - 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - -[:, :, 3] = - 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -``` +An array `D` containing the normalized embeddings for each token in each document. +It has shape `(D, L, N)`, where `D` is the embedding dimension (`128` for the linear layer +of ColBERT), and `(L, N)` is the shape of `integer_ids`, i.e `L` is the maximum length of +any document and `N` is the total number of documents. """ function doc(bert::HF.HGFBertModel, linear::Layers.Dense, integer_ids::AbstractMatrix{Int32}, bitmask::AbstractMatrix{Bool}) @@ -101,20 +71,26 @@ function _query_embeddings( end """ - encode_passages( - config::ColBERTConfig, checkpoint::Checkpoint, passages::Vector{String}) + encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + passages::Vector{String}, dim::Int, index_bsize::Int, + doc_token::String, skiplist::Vector{Int}) -Encode a list of passages using `checkpoint`. +Encode a list of document passages. The given `passages` are run through the underlying BERT model and the linear layer to generate the embeddings, after doing relevant document-specific preprocessing. -See [`docFromText`](@ref) for more details. # Arguments - - `config`: The [`ColBERTConfig`](@ref) to be used. - - `checkpoint`: The [`Checkpoint`](@ref) used to encode the passages. + - `bert`: The pre-trained BERT component of the ColBERT model. + - `linear`: The pre-trained linear component of the ColBERT model. 
+ - `tokenizer`: The tokenizer to be used. - `passages`: A list of strings representing the passages to be encoded. + - `dim`: The embedding dimension. + - `index_bsize`: The batch size to be used for running the transformer. + - `doc_token`: The document token. + - `skiplist`: A list of tokens to skip. # Returns @@ -213,23 +189,32 @@ function encode_passages(bert::HF.HGFBertModel, linear::Layers.Dense, end """ - encode_query(searcher::Searcher, query::String) + encode_queries(bert::HF.HGFBertModel, linear::Layers.Dense, + tokenizer::TextEncoders.AbstractTransformerTextEncoder, + queries::Vector{String}, dim::Int, + index_bsize::Int, query_token::String, attend_to_mask_tokens::Bool, + skiplist::Vector{Int}) -Encode a search query to a matrix of embeddings using the provided `searcher`. The encoded query can then be used to search the collection. +Encode a list of query passages. # Arguments - - `searcher`: A Searcher object that contains information about the collection and the index. - - `query`: The search query to encode. + - `bert`: The pre-trained BERT component of the ColBERT model. + - `linear`: The pre-trained linear component of the ColBERT model. + - `tokenizer`: The tokenizer to be used. + - `queries`: A list of strings representing the queries to be encoded. + - `dim`: The embedding dimension. + - `index_bsize`: The batch size to be used for running the transformer. + - `query_token`: The query token. +- `attend_to_mask_tokens`: Whether to attend to `"[MASK]"` tokens. + - `skiplist`: A list of tokens to skip. # Returns -An array containing the embeddings for each token in the query. Also see [queryFromText](@ref) to see the size of the array. +An array containing the embeddings for each token in the query. # Examples -Here's an example using the `config` and `checkpoint` from the example for [`Checkpoint`](@ref). - ```julia-repl julia> using ColBERT: load_hgf_pretrained_local, ColBERTConfig, encode_queries; diff --git a/src/savers.jl b/src/savers.jl index 8a1fb03..d8a6a1f 100644 --- a/src/savers.jl +++ b/src/savers.jl @@ -11,7 +11,7 @@ Save compression/decompression information from the index path. - `centroids`: The matrix of centroids of the index. - `bucket_cutoffs`: Cutoffs used to determine buckets during residual compression. - `bucket_weights`: Weights used to determine the decompressed values during decompression. - - `avg_residual`: The average residual value, computed from the heldout set (see [`_compute_avg_residuals`](@ref)). + - `avg_residual`: The average residual value, computed from the heldout set (see `_compute_avg_residuals`). """ function save_codec( index_path::String, centroids::Matrix{Float32}, bucket_cutoffs::Vector{Float32}, @@ -30,8 +30,8 @@ end """ save_chunk( - config::ColBERTConfig, codec::Dict, chunk_idx::Int, passage_offset::Int, - embs::AbstractMatrix{Float32}, doclens::AbstractVector{Int}) + index_path::String, codes::AbstractVector{UInt32}, residuals::AbstractMatrix{UInt8}, + chunk_idx::Int, passage_offset::Int, doclens::AbstractVector{Int}) Save a single chunk of compressed embeddings and their relevant metadata to disk. 
@@ -42,10 +42,11 @@ number of embeddings and the passage offsets are saved in a file named `(0), doclens) - zero_docs = findall(==(0), doclens) + doclens = rand(0:100, rand(1:500)) + non_zero_docs = findall(>(0), doclens) + zero_docs = findall(==(0), doclens) emb2pid = _build_emb2pid(doclens) @test all(in(non_zero_docs), emb2pid) @test issorted(emb2pid) for pid in non_zero_docs @test count(==(pid), emb2pid) == doclens[pid] end - @test length(emb2pid) == sum(doclens[non_zero_docs]) + @test length(emb2pid) == sum(doclens[non_zero_docs]) @test emb2pid isa Vector{Int} end -
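
For readers of the new `maxsim` tests above, here is a compact, self-contained
sketch of the MaxSim scoring they exercise: the similarity matrix `Q' * D` is
sliced into per-passage column blocks, and each passage's score is the sum of
the row-wise maxima of its block. The snippet reuses the toy numbers from
"Test 1: Basic test" and is illustrative only; it is not the package's internal
implementation.

```julia
using Test

# Standalone sketch of MaxSim scoring, mirroring "Test 1: Basic test" in
# test/search/ranking.jl above. Illustrative only.
Q = Float32[1.0 0.5; 0.5 1.0]           # 2 query token embeddings (columns)
D = Float32[0.8 0.3 0.1; 0.2 0.7 0.4]   # 3 document token embeddings (columns)
doclens = [1, 2]                        # passage 1 owns column 1, passage 2 owns columns 2:3

similarities = Q' * D                   # (query tokens) x (document tokens)
offsets = cumsum([1; doclens[1:(end - 1)]])
scores = map(eachindex(doclens)) do pid
    cols = offsets[pid]:(offsets[pid] + doclens[pid] - 1)
    # for each query token, keep its best-matching document token, then sum
    sum(maximum(similarities[:, cols], dims = 2))
end
@test scores ≈ [1.5, 1.5]               # matches expected_scores in Test 1
```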