diff --git a/README.md b/README.md index 01f066e..4612677 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ fast**RAG** is a research framework designed to facilitate the building of retri ## Updates +- **June 2023**: ColBERT index modification: adding/removing documents; see [IndexUpdater](libs/colbert/colbert/index_updater.py). - **May 2023**: [RAG with LLM and dynamic prompt synthesis example](examples/rag-prompt-hf.ipynb). - **April 2023**: Qdrant `DocumentStore` support. diff --git a/fastrag/__init__.py b/fastrag/__init__.py index 943fa38..0cc09fb 100644 --- a/fastrag/__init__.py +++ b/fastrag/__init__.py @@ -4,7 +4,7 @@ from fastrag import image_generators, kg_creators, rankers, readers, retrievers, stores from fastrag.utils import add_timing_to_pipeline -__version__ = "1.2.0" +__version__ = "1.3.0" def load_pipeline(config_path: str) -> Pipeline: diff --git a/libs/colbert/README.md b/libs/colbert/README.md index 0226e53..d335364 100644 --- a/libs/colbert/README.md +++ b/libs/colbert/README.md @@ -1,3 +1,8 @@ +## 🚨 **Announcements** + +* (1/29/23) We have merged a new index updater feature and support for additional Hugging Face models! These are in beta, so please give us feedback as you try them out. +* (1/24/23) If you're looking for the **DSP** framework for composing ColBERTv2 and LLMs, it's at: https://github.com/stanfordnlp/dsp + # ColBERT (v2) ### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds. @@ -18,7 +23,7 @@ These rich interactions allow ColBERT to surpass the quality of _single-vector_ * [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21). * [**Baleen: Robust Multi-Hop Reasoning at Scale via Condensed Retrieval**](https://arxiv.org/abs/2101.00436) (NeurIPS'21). * [**ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction**](https://arxiv.org/abs/2112.01488) (NAACL'22). -* [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (preprint). +* [**PLAID: An Efficient Engine for Late Interaction Retrieval**](https://arxiv.org/abs/2205.09707) (CIKM'22). ---- @@ -29,7 +34,7 @@ The ColBERTv1 code from the SIGIR'20 paper is in the [`colbertv1` branch](https: ## Installation -ColBERT requires Python 3.7+ and Pytorch 1.9+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library. +ColBERT requires Python 3.7+ and PyTorch 1.9+ and uses the [Hugging Face Transformers](https://github.com/huggingface/transformers) library. We strongly recommend creating a conda environment using the commands below. (If you don't have conda, follow the official [conda installation guide](https://docs.anaconda.com/anaconda/install/linux/#installation).) @@ -161,6 +166,19 @@ if __name__=='__main__': print(f"Saved checkpoint to {checkpoint_path}...") ``` +## Running a lightweight ColBERTv2 server +We provide a script to run a lightweight server that serves the top-k (up to 100) results for a given search query, in ranked order and in JSON format. This script can be used to power DSP programs. + +To run the server, update the environment variables `INDEX_ROOT` and `INDEX_NAME` in the `.env` file to point to the appropriate ColBERT index.
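+
+For example, a minimal `.env` might look like the following (illustrative values only; point `INDEX_ROOT` and `INDEX_NAME` at your own index, and set `PORT` to any free port; the sample query below assumes 8893):
+```
+INDEX_ROOT=/path/to/indexes
+INDEX_NAME=my_colbert_index
+PORT=8893
+```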
+Then run the following command: +``` +python server.py +``` + +A sample query: +``` +http://localhost:8893/api/search?query=Who won the 2022 FIFA world cup&k=25 +``` + ## Branches ### Supported branches diff --git a/libs/colbert/colbert/__init__.py b/libs/colbert/colbert/__init__.py index 5bb580d..3e8bd88 100644 --- a/libs/colbert/colbert/__init__.py +++ b/libs/colbert/colbert/__init__.py @@ -1,3 +1,4 @@ +from .index_updater import IndexUpdater from .indexer import Indexer from .modeling.checkpoint import Checkpoint from .searcher import Searcher diff --git a/libs/colbert/colbert/distillation/ranking_scorer.py b/libs/colbert/colbert/distillation/ranking_scorer.py new file mode 100644 index 0000000..1787457 --- /dev/null +++ b/libs/colbert/colbert/distillation/ranking_scorer.py @@ -0,0 +1,50 @@ +from collections import defaultdict + +import tqdm +import ujson +from colbert.data import Ranking +from colbert.distillation.scorer import Scorer +from colbert.infra import Run +from colbert.infra.provenance import Provenance +from colbert.utility.utils.save_metadata import get_metadata_only +from colbert.utils.utils import print_message, zipstar + + +class RankingScorer: + def __init__(self, scorer: Scorer, ranking: Ranking): + self.scorer = scorer + self.ranking = ranking.tolist() + self.__provenance = Provenance() + + print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!") + + def provenance(self): + return self.__provenance + + def run(self): + print_message("#> Starting...") + + qids, pids, *_ = zipstar(self.ranking) + distillation_scores = self.scorer.launch(qids, pids) + + scores_by_qid = defaultdict(list) + + for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)): + scores_by_qid[qid].append((score, pid)) + + with Run().open("distillation_scores.json", "w") as f: + for qid in tqdm.tqdm(scores_by_qid): + obj = (qid, scores_by_qid[qid]) + f.write(ujson.dumps(obj) + "\n") + + output_path = f.name + print_message(f"#> Saved the distillation_scores to {output_path}") + + with Run().open(f"{output_path}.meta", "w") as f: + d = {} + d["metadata"] = get_metadata_only() + d["provenance"] = self.provenance() + line = ujson.dumps(d, indent=4) + f.write(line) + + return output_path diff --git a/libs/colbert/colbert/distillation/scorer.py b/libs/colbert/colbert/distillation/scorer.py new file mode 100644 index 0000000..0c634b4 --- /dev/null +++ b/libs/colbert/colbert/distillation/scorer.py @@ -0,0 +1,75 @@ +import torch +import tqdm +from colbert.infra import Run, RunConfig +from colbert.infra.launcher import Launcher +from colbert.modeling.reranker.electra import ElectraReranker +from colbert.utils.utils import flatten +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + +class Scorer: + def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256): + self.queries = queries + self.collection = collection + self.model = model + + self.maxlen = maxlen + self.bsize = bsize + + def launch(self, qids, pids): + launcher = Launcher(self._score_pairs_process, return_all=True) + outputs = launcher.launch(Run().config, qids, pids) + + return flatten(outputs) + + def _score_pairs_process(self, config, qids, pids): + assert len(qids) == len(pids), (len(qids), len(pids)) + share = 1 + len(qids) // config.nranks + offset = config.rank * share + endpos = (1 + config.rank) * share + + return self._score_pairs( + qids[offset:endpos], pids[offset:endpos],
show_progress=(config.rank < 1) + ) + + def _score_pairs(self, qids, pids, show_progress=False): + tokenizer = AutoTokenizer.from_pretrained(self.model) + model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda() + + assert len(qids) == len(pids), (len(qids), len(pids)) + + scores = [] + + model.eval() + with torch.inference_mode(): + with torch.cuda.amp.autocast(): + for offset in tqdm.tqdm( + range(0, len(qids), self.bsize), disable=(not show_progress) + ): + endpos = offset + self.bsize + + queries_ = [self.queries[qid] for qid in qids[offset:endpos]] + passages_ = [self.collection[pid] for pid in pids[offset:endpos]] + + features = tokenizer( + queries_, + passages_, + padding="longest", + truncation=True, + return_tensors="pt", + max_length=self.maxlen, + ).to(model.device) + + scores.append(model(**features).logits.flatten()) + + scores = torch.cat(scores) + scores = scores.tolist() + + Run().print(f"Returning with {len(scores)} scores") + + return scores + + +# LONG-TERM TODO: This can be sped up by sorting by length in advance. diff --git a/libs/colbert/colbert/evaluation/loaders.py b/libs/colbert/colbert/evaluation/loaders.py index d0e7ec0..e6f70f3 100644 --- a/libs/colbert/colbert/evaluation/loaders.py +++ b/libs/colbert/colbert/evaluation/loaders.py @@ -176,8 +176,7 @@ def load_collection(collection_path): print(f"{line_idx // 1000 // 1000}M", end=" ", flush=True) pid, passage, *rest = line.strip("\n\r ").split("\t") - # id could be either "id" (the first line), a number or have the format "docNUM" - assert pid == "id" or int(pid if pid.isnumeric() else pid[3:]) == line_idx + assert pid == "id" or int(pid) == line_idx, f"pid={pid}, line_idx={line_idx}" if len(rest) >= 1: title = rest[0] diff --git a/libs/colbert/colbert/index_updater.py b/libs/colbert/colbert/index_updater.py new file mode 100644 index 0000000..e4130b8 --- /dev/null +++ b/libs/colbert/colbert/index_updater.py @@ -0,0 +1,481 @@ +import os + +import numpy as np +import torch +import tqdm +import ujson +from colbert.data import Collection +from colbert.indexing.codecs.residual import ResidualCodec +from colbert.indexing.codecs.residual_embeddings import ResidualEmbeddings +from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided +from colbert.indexing.collection_encoder import CollectionEncoder +from colbert.indexing.index_saver import IndexSaver +from colbert.indexing.utils import optimize_ivf +from colbert.modeling.checkpoint import Checkpoint +from colbert.search.index_loader import IndexLoader +from colbert.search.strided_tensor import StridedTensor +from colbert.utils.utils import batch, dotdict, flatten, lengths2offsets, print_message + +# For testing writing into new chunks, can set DEFAULT_CHUNKSIZE smaller (e.g. 1 or 2) +DEFAULT_CHUNKSIZE = 25000 + + +class IndexUpdater: + + """ + IndexUpdater takes in a searcher and adds/removes passages from the searcher. + A checkpoint for passage encoding must be provided for adding passages. + IndexUpdater can also persist these changes to the index on disk.
+ + Sample usage: + + index_updater = IndexUpdater(config, searcher, checkpoint) + + added_pids = index_updater.add(passages) # all passages are added to the searcher, and their new pids are returned + index_updater.remove(pids) # all passages with these pids are removed from the searcher + + searcher.search() # the search now reflects the added & removed passages + + index_updater.persist_to_disk() # added & removed passages are persisted to the index on disk + searcher = Searcher(index, config) # if we now reload the searcher from the on-disk index, the changes we made persist + + """ + + def __init__(self, config, searcher, checkpoint=None): + self.config = config + self.searcher = searcher + self.index_path = searcher.index + + self.has_checkpoint = False + if checkpoint: + self.has_checkpoint = True + self.checkpoint = Checkpoint(checkpoint, config) + self.encoder = CollectionEncoder(config, self.checkpoint) + + self._load_disk_ivf() + + # variables to track removal / append of passages + self.removed_pids = [] + self.first_new_emb = torch.sum(self.searcher.ranker.doclens).item() + self.first_new_pid = len(self.searcher.ranker.doclens) + + def remove(self, pids): + """ + Input: + pids: list(int) + Return: None + + Removes a list of pids from the searcher; these pids will no longer appear + in future searches with this searcher. + To erase the passage data from the index on disk, call persist_to_disk() after calling remove(). + """ + print_message(f"#> Removing pids: {pids}...") + self._remove_pid_from_ivf(pids) + self.removed_pids.extend(pids) + + def add(self, passages): + """ + Input: + passages: list(string) + Output: + passage_ids: list(int) + + Adds new passages to the searcher. + To persist the added passages to the index on disk, call persist_to_disk() after calling add(). + """ + if not self.has_checkpoint: + raise ValueError("No checkpoint was provided at IndexUpdater initialization.") + + # Find pid for the first added passage + start_pid = len(self.searcher.ranker.doclens) + curr_pid = start_pid + + # Extend doclens and embs of self.searcher.ranker + embs, doclens = self.encoder.encode_passages(passages) + compressed_embs = self.searcher.ranker.codec.compress(embs) + + # Update searcher + # NOTE: For codes and residuals, the tensors end with padding of length 512, + # hence we concatenate the new appendage in front of the padding + self.searcher.ranker.embeddings.codes = torch.cat( + ( + self.searcher.ranker.embeddings.codes[:-512], + compressed_embs.codes, + self.searcher.ranker.embeddings.codes[-512:], + ) + ) + self.searcher.ranker.embeddings.residuals = torch.cat( + ( + self.searcher.ranker.embeddings.residuals[:-512], + compressed_embs.residuals, + self.searcher.ranker.embeddings.residuals[-512:], + ), + dim=0, + ) + + self.searcher.ranker.doclens = torch.cat( + (self.searcher.ranker.doclens, torch.tensor(doclens)) + ) + + # Build partitions for each pid and update IndexUpdater's current ivf + start = 0 + for doclen in doclens: + end = start + doclen + codes = compressed_embs.codes[start:end] + partitions, _ = self._build_passage_partitions(codes) + self._add_pid_to_ivf(partitions, curr_pid) + + start = end + curr_pid += 1 + + assert start == sum(doclens) + + # Update new ivf in searcher + new_ivf_tensor = StridedTensor(self.curr_ivf, self.curr_ivf_lengths, use_gpu=False) + assert new_ivf_tensor != self.searcher.ranker.ivf + self.searcher.ranker.ivf = new_ivf_tensor + + # Rebuild StridedTensor within searcher + self.searcher.ranker.embeddings_strided = ResidualEmbeddingsStrided( + self.searcher.ranker.codec, + self.searcher.ranker.embeddings,
self.searcher.ranker.doclens, + ) + + print_message(f"#> Added {len(passages)} passages from pid {start_pid}.") + new_pids = list(range(start_pid, start_pid + len(passages))) + return new_pids + + def persist_to_disk(self): + """ + Persists all previously stored changes in IndexUpdater to the index on disk; + this covers all calls to IndexUpdater.remove() and IndexUpdater.add() + made before persist_to_disk() is called. + """ + + print_message("#> Persisting index changes to disk") + + # Propagate all removed passages to disk + self._load_metadata() + for pid in self.removed_pids: + self._remove_passage_from_disk(pid) + + # Propagate all added passages to disk + # Rationale: keep record of all added passages in IndexUpdater.searcher, + # divide passages into chunks and create / write chunks here + + self._load_metadata() # Reload after removal + + # Calculate avg number of passages per chunk + curr_num_chunks = self.metadata["num_chunks"] + last_chunk_metadata = self._load_chunk_metadata(curr_num_chunks - 1) + if curr_num_chunks == 1: + avg_chunksize = DEFAULT_CHUNKSIZE + else: + # Integer average, so that the passage counts used for slicing below stay ints + avg_chunksize = int(last_chunk_metadata["passage_offset"] / (curr_num_chunks - 1)) + print_message(f"#> Current average chunksize is: {avg_chunksize}.") + + # Calculate number of additional passages we can write to the last chunk + last_chunk_capacity = max(0, avg_chunksize - last_chunk_metadata["num_passages"]) + print_message(f"#> The last chunk can hold {last_chunk_capacity} additional passages.") + + # Find the first and last passages to be persisted + pid_start = self.first_new_pid + emb_start = self.first_new_emb + pid_last = len(self.searcher.ranker.doclens) + emb_last = emb_start + torch.sum(self.searcher.ranker.doclens[pid_start:]).item() + + # First populate the last chunk + if last_chunk_capacity > 0: + pid_end = min(pid_last, pid_start + last_chunk_capacity) + emb_end = emb_start + torch.sum(self.searcher.ranker.doclens[pid_start:pid_end]).item() + + # Write to last chunk + self._write_to_last_chunk(pid_start, pid_end, emb_start, emb_end) + pid_start = pid_end + emb_start = emb_end + + # Then create new chunks to hold the remaining added passages + while pid_start < pid_last: + pid_end = min(pid_last, pid_start + avg_chunksize) + emb_end = emb_start + torch.sum(self.searcher.ranker.doclens[pid_start:pid_end]).item() + + # Write new chunk with id = curr_num_chunks + self._write_to_new_chunk(curr_num_chunks, pid_start, pid_end, emb_start, emb_end) + + curr_num_chunks += 1 + pid_start = pid_end + emb_start = emb_end + + assert pid_start == pid_last + assert emb_start == emb_last + + # Update metadata + print_message("#> Updating metadata for added passages...") + self.metadata["num_chunks"] = curr_num_chunks + self.metadata["num_embeddings"] = torch.sum(self.searcher.ranker.doclens).item() + metadata_path = os.path.join(self.index_path, "metadata.json") + with open(metadata_path, "w") as output_metadata: + ujson.dump(self.metadata, output_metadata) + + # Save current IVF to disk + optimized_ivf_path = os.path.join(self.index_path, "ivf.pid.pt") + torch.save((self.curr_ivf, self.curr_ivf_lengths), optimized_ivf_path) + print_message(f"#> Persisted updated IVF to {optimized_ivf_path}") + + # HELPER FUNCTIONS BELOW + + def _load_disk_ivf(self): + print_message("#> Loading IVF...") + + if os.path.exists(os.path.join(self.index_path, "ivf.pid.pt")): + ivf, ivf_lengths = torch.load( + os.path.join(self.index_path, "ivf.pid.pt"), map_location="cpu" + ) + else: + assert os.path.exists(os.path.join(self.index_path, "ivf.pt"))
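+            # Legacy index layout: only the embedding-level ivf.pt exists, so load it
+            # and convert it below into the passage-level ivf.pid.pt layout via optimize_ivf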
+ ivf, ivf_lengths = torch.load( + os.path.join(self.index_path, "ivf.pt"), map_location="cpu" + ) + ivf, ivf_lengths = optimize_ivf(ivf, ivf_lengths, self.index_path) + + self.curr_ivf = ivf + self.curr_ivf_lengths = ivf_lengths + + def _load_metadata(self): + with open(os.path.join(self.index_path, "metadata.json")) as f: + self.metadata = ujson.load(f) + + def _load_chunk_doclens(self, chunk_idx): + doclens = [] + + print_message("#> Loading doclens...") + + with open(os.path.join(self.index_path, f"doclens.{chunk_idx}.json")) as f: + chunk_doclens = ujson.load(f) + doclens.extend(chunk_doclens) + + doclens = torch.tensor(doclens) + return doclens + + def _load_chunk_codes(self, chunk_idx): + codes_path = os.path.join(self.index_path, f"{chunk_idx}.codes.pt") + return torch.load(codes_path, map_location="cpu") + + def _load_chunk_residuals(self, chunk_idx): + residuals_path = os.path.join(self.index_path, f"{chunk_idx}.residuals.pt") + return torch.load(residuals_path, map_location="cpu") + + def _load_chunk_metadata(self, chunk_idx): + with open(os.path.join(self.index_path, f"{chunk_idx}.metadata.json")) as f: + chunk_metadata = ujson.load(f) + return chunk_metadata + + def _get_chunk_idx(self, pid): + for i in range(self.metadata["num_chunks"]): + chunk_metadata = self._load_chunk_metadata(i) + if ( + chunk_metadata["passage_offset"] <= pid + and chunk_metadata["passage_offset"] + chunk_metadata["num_passages"] > pid + ): + return i + raise ValueError("Passage ID out of range") + + def _remove_pid_from_ivf(self, pids): + # Helper function for IndexUpdater.remove() + + pids = set(pids) + + # Construct mask of where pids to be removed appear in ivf + mask = torch.isin(self.curr_ivf, torch.tensor(list(pids))) + indices = mask.nonzero() + + # Calculate end-indices of each centroid section in ivf + section_end_indices = [] + c = 0 + for length in self.curr_ivf_lengths.tolist(): + c += length + section_end_indices.append(c) + + # Record the number of pids removed from each centroid section + removed_len = [0 for _ in range(len(section_end_indices))] + j = 0 + for ind in indices: + while ind >= section_end_indices[j]: + j += 1 + removed_len[j] += 1 + + # Update changes + new_ivf = torch.masked_select(self.curr_ivf, ~mask) + new_ivf_lengths = self.curr_ivf_lengths - torch.tensor(removed_len) + + new_ivf_tensor = StridedTensor(new_ivf, new_ivf_lengths, use_gpu=False) + assert new_ivf_tensor != self.searcher.ranker.ivf + self.searcher.ranker.ivf = new_ivf_tensor + + self.curr_ivf = new_ivf + self.curr_ivf_lengths = new_ivf_lengths + + def _build_passage_partitions(self, codes): + # Helper function for IndexUpdater.add() + # Returns a list of ordered, unique centroid ids from the codes of a passage + codes = codes.sort() + ivf, values = codes.indices, codes.values + partitions, ivf_lengths = values.unique_consecutive(return_counts=True) + return partitions, ivf_lengths + + def _add_pid_to_ivf(self, partitions, pid): + """ + Helper function for IndexUpdater.add() + + Input: + partitions: list(int), centroid ids of the passage + pid: int, passage id + Output: None + + Adds the pid of a new passage into the ivf.
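+        Note: this implementation copies the entire ivf once per added passage,
+        so each call costs time linear in the ivf size.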
+ """ + new_ivf = [] + new_ivf_lengths = [] + old_ivf = self.curr_ivf.tolist() + old_ivf_lengths = self.curr_ivf_lengths.tolist() + + partitions_runner = 0 + ivf_runner = 0 + for i in range(len(old_ivf_lengths)): + # First copy existing partition pids to new ivf + new_ivf.extend(old_ivf[ivf_runner : ivf_runner + old_ivf_lengths[i]]) + new_ivf_lengths.append(old_ivf_lengths[i]) + ivf_runner += old_ivf_lengths[i] + + # Add pid if partition_index i is in the passage's partitions + if partitions_runner < len(partitions) and i == partitions[partitions_runner]: + new_ivf.append(pid) + new_ivf_lengths[-1] += 1 + partitions_runner += 1 + + assert ivf_runner == len(old_ivf) + assert sum(new_ivf_lengths) == len(new_ivf) + + # Replace the current ivf with new_ivf + self.curr_ivf = torch.tensor(new_ivf) + self.curr_ivf_lengths = torch.tensor(new_ivf_lengths) + + def _write_to_last_chunk(self, pid_start, pid_end, emb_start, emb_end): + # Helper function for IndexUpdater.persist_to_disk() + + print_message(f"#> Writing {pid_end - pid_start} passages to the last chunk...") + num_chunks = self.metadata["num_chunks"] + + # Append to current last chunk + curr_embs = ResidualEmbeddings.load(self.index_path, num_chunks - 1) + curr_embs.codes = torch.cat( + (curr_embs.codes, self.searcher.ranker.embeddings.codes[emb_start:emb_end]) + ) + curr_embs.residuals = torch.cat( + ( + curr_embs.residuals, + self.searcher.ranker.embeddings.residuals[emb_start:emb_end], + ) + ) + path_prefix = os.path.join(self.index_path, f"{num_chunks - 1}") + curr_embs.save(path_prefix) + + # Update doclen of last chunk + curr_doclens = self._load_chunk_doclens(num_chunks - 1).tolist() + curr_doclens.extend(self.searcher.ranker.doclens.tolist()[pid_start:pid_end]) + doclens_path = os.path.join(self.index_path, f"doclens.{num_chunks - 1}.json") + with open(doclens_path, "w") as output_doclens: + ujson.dump(curr_doclens, output_doclens) + + # Update metadata of last chunk + chunk_metadata = self._load_chunk_metadata(num_chunks - 1) + chunk_metadata["num_passages"] += pid_end - pid_start + chunk_metadata["num_embeddings"] += emb_end - emb_start + chunk_metadata_path = os.path.join(self.index_path, f"{num_chunks - 1}.metadata.json") + with open(chunk_metadata_path, "w") as output_chunk_metadata: + ujson.dump(chunk_metadata, output_chunk_metadata) + + def _write_to_new_chunk(self, chunk_idx, pid_start, pid_end, emb_start, emb_end): + # Helper function for IndexUpdater.persist_to_disk() + + # Save embeddings to new chunk + curr_embs = ResidualEmbeddings( + self.searcher.ranker.embeddings.codes[emb_start:emb_end], + self.searcher.ranker.embeddings.residuals[emb_start:emb_end], + ) + path_prefix = os.path.join(self.index_path, f"{chunk_idx}") + curr_embs.save(path_prefix) + + # Create doclen json file for new chunk + curr_doclens = self.searcher.ranker.doclens.tolist()[pid_start:pid_end] + doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json") + with open(doclens_path, "w+") as output_doclens: + ujson.dump(curr_doclens, output_doclens) + + # Create metadata json file for new chunk + chunk_metadata = { + "passage_offset": pid_start, + "num_passages": pid_end - pid_start, + "embedding_offset": emb_start, + "num_embeddings": emb_end - emb_start, + } + chunk_metadata_path = os.path.join(self.index_path, f"{chunk_idx}.metadata.json") + with open(chunk_metadata_path, "w+") as output_chunk_metadata: + ujson.dump(chunk_metadata, output_chunk_metadata) + + def _remove_passage_from_disk(self, pid): + # Helper function for 
IndexUpdater.persist_to_disk() + + chunk_idx = self._get_chunk_idx(pid) + + chunk_metadata = self._load_chunk_metadata(chunk_idx) + i = pid - chunk_metadata["passage_offset"] + doclens = self._load_chunk_doclens(chunk_idx) + codes, residuals = ( + self._load_chunk_codes(chunk_idx), + self._load_chunk_residuals(chunk_idx), + ) + + # Remove embeddings from codes and residuals + start = sum(doclens[:i]) + end = start + doclens[i] + codes = torch.cat((codes[:start], codes[end:])) + residuals = torch.cat((residuals[:start], residuals[end:])) + + codes_path = os.path.join(self.index_path, f"{chunk_idx}.codes.pt") + residuals_path = os.path.join(self.index_path, f"{chunk_idx}.residuals.pt") + + torch.save(codes, codes_path) + torch.save(residuals, residuals_path) + + # Change doclen for passage to 0 + doclens = doclens.tolist() + doclen_to_remove = doclens[i] + doclens[i] = 0 + doclens_path = os.path.join(self.index_path, f"doclens.{chunk_idx}.json") + with open(doclens_path, "w") as output_doclens: + ujson.dump(doclens, output_doclens) + + # Modify chunk_metadata['num_embeddings'] for chunk_idx + chunk_metadata["num_embeddings"] -= doclen_to_remove + chunk_metadata_path = os.path.join(self.index_path, f"{chunk_idx}.metadata.json") + with open(chunk_metadata_path, "w") as output_chunk_metadata: + ujson.dump(chunk_metadata, output_chunk_metadata) + + # Modify chunk_metadata['embedding_offset'] for all later chunks (minus num_embs_removed) + for idx in range(chunk_idx + 1, self.metadata["num_chunks"]): + metadata = self._load_chunk_metadata(idx) + metadata["embedding_offset"] -= doclen_to_remove + metadata_path = os.path.join(self.index_path, f"{idx}.metadata.json") + with open(metadata_path, "w") as output_chunk_metadata: + ujson.dump(metadata, output_chunk_metadata) + + # Modify num_embeddings in overall metadata (minus num_embs_removed) + self.metadata["num_embeddings"] -= doclen_to_remove + metadata_path = os.path.join(self.index_path, "metadata.json") + with open(metadata_path, "w") as output_metadata: + ujson.dump(self.metadata, output_metadata) diff --git a/libs/colbert/colbert/indexer.py b/libs/colbert/colbert/indexer.py index 5e95ec3..4f884d4 100644 --- a/libs/colbert/colbert/indexer.py +++ b/libs/colbert/colbert/indexer.py @@ -83,5 +83,6 @@ def __launch(self, collection): shared_lists = [manager.list() for _ in range(self.config.nranks)] shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)] + # Encodes the collection into an index using the CollectionIndexer class launcher = Launcher(encode) launcher.launch(self.config, collection, shared_lists, shared_queues) diff --git a/libs/colbert/colbert/indexing/collection_indexer.py b/libs/colbert/colbert/indexing/collection_indexer.py index dc0079e..782028b 100644 --- a/libs/colbert/colbert/indexing/collection_indexer.py +++ b/libs/colbert/colbert/indexing/collection_indexer.py @@ -32,6 +32,11 @@ def encode(config, collection, shared_lists, shared_queues): class CollectionIndexer: + """ + Given a collection and a config, encodes the collection into an index and + stores the index on disk in chunks.
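+    The stages setup(), train(), index() and finalize() are executed in order by run().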
+ """ + def __init__(self, config: ColBERTConfig, collection): self.config = config self.rank, self.nranks = self.config.rank, self.config.nranks @@ -53,24 +58,30 @@ def __init__(self, config: ColBERTConfig, collection): def run(self, shared_lists): with torch.inference_mode(): - self.setup() + self.setup() # Computes and saves plan for whole collection distributed.barrier(self.rank) print_memory_stats(f"RANK:{self.rank}") if not self.config.resume or not self.saver.try_load_codec(): - self.train(shared_lists) + self.train(shared_lists) # Trains centroids from selected passages distributed.barrier(self.rank) print_memory_stats(f"RANK:{self.rank}") - self.index() + self.index() # Encodes and saves all tokens into residuals distributed.barrier(self.rank) print_memory_stats(f"RANK:{self.rank}") - self.finalize() + self.finalize() # Builds metadata and centroid-to-passage mapping distributed.barrier(self.rank) print_memory_stats(f"RANK:{self.rank}") def setup(self): + """ + Calculates and saves plan.json for the whole collection. + + plan.json: { config, num_chunks, num_partitions, num_embeddings_est, avg_doclen_est } + num_partitions is the number of centroids to be generated. + """ if self.config.resume: if self._try_load_plan(): Run().print_main(f"#> Loaded plan from {self.plan_path}:") @@ -82,6 +93,7 @@ def setup(self): self.num_chunks = int(np.ceil(len(self.collection) / self.collection.get_chunksize())) + # Saves sampled passages and embeddings for training k-means centroids later sampled_pids = self._sample_pids() avg_doclen_est = self._sample_embeddings(sampled_pids) @@ -223,6 +235,7 @@ def train(self, shared_lists): print_message(f"avg_residual = {avg_residual}") + # Compute and save the codec into avg_residual.pt, buckets.pt and centroids.pt codec = ResidualCodec( config=self.config, centroids=centroids, @@ -269,20 +282,27 @@ def _train_kmeans(self, sample, shared_lists): if self.use_gpu: torch.cuda.empty_cache() + do_fork_for_faiss = False # set to True to free faiss GPU-0 memory at the cost of one more copy of `sample`. + args_ = [self.config.dim, self.num_partitions, self.config.kmeans_niters] - # shared_lists[0][0] = sample - # return_value_queue = mp.Queue() + if do_fork_for_faiss: + # For this to work reliably, write the sample to disk. Pickle may not handle >4GB of data. + # Delete the sample file after work is done. - # args_ = args_ + [shared_lists, return_value_queue] - # proc = mp.Process(target=compute_faiss_kmeans, args=args_) + shared_lists[0][0] = sample + return_value_queue = mp.Queue() - # proc.start() - # centroids = return_value_queue.get() - # proc.join() + args_ = args_ + [shared_lists, return_value_queue] + proc = mp.Process(target=compute_faiss_kmeans, args=args_) - args_ = args_ + [[[sample]]] - centroids = compute_faiss_kmeans(*args_) + proc.start() + centroids = return_value_queue.get() + proc.join() + + else: + args_ = args_ + [[[sample]]] + centroids = compute_faiss_kmeans(*args_) centroids = torch.nn.functional.normalize(centroids, dim=-1) if self.use_gpu: @@ -335,6 +355,16 @@ def _compute_avg_residual(self, centroids, heldout): # sample_avg_residual = (sample - sample_reconstruct).mean(dim=0) def index(self): + """ + Encodes embeddings for all passages in the collection. + Each embedding is converted to a code (centroid id) and a residual. + Embeddings are stored in passage order, in contiguous chunks of memory.
+ + Saved data files are described below: + {CHUNK#}.codes.pt: centroid id for each embedding in the chunk + {CHUNK#}.residuals.pt: 16-bit residual for each embedding in the chunk + doclens.{CHUNK#}.json: number of embeddings within each passage in the chunk + """ with self.saver.thread(): batches = self.collection.enumerate_batches(rank=self.rank) for chunk_idx, offset, passages in tqdm.tqdm(batches, disable=self.rank > 0): @@ -343,6 +373,7 @@ def index(self): f"#> Found chunk {chunk_idx} in the index already, skipping encoding..." ) continue + # Encode passages into embeddings with the checkpoint model embs, doclens = self.encoder.encode_passages(passages) if self.use_gpu: assert embs.dtype == torch.float16 @@ -355,10 +386,23 @@ def index(self): f"and {embs.size(0):,} embeddings. From #{offset:,} onward." ) - self.saver.save_chunk(chunk_idx, offset, embs, doclens) + self.saver.save_chunk( + chunk_idx, offset, embs, doclens + ) # offset = first passage index in chunk del embs, doclens def finalize(self): + """ + Aggregates and stores metadata for each chunk and for the whole index. + Builds and saves the inverse mapping from centroids to passage IDs. + + Saved data files are described below: + {CHUNK#}.metadata.json: [ passage_offset, num_passages, num_embeddings, embedding_offset ] + metadata.json: [ num_chunks, num_partitions, num_embeddings, avg_doclen ] + ivf.pid.pt: [ ivf, ivf_lengths ] + ivf is an array of passage IDs for centroids 0, 1, ... + ivf_lengths contains the number of passage IDs for each centroid + """ if self.rank > 0: return @@ -453,6 +497,7 @@ def _build_ivf(self): print_memory_stats(f"RANK:{self.rank}") + # Transforms the centroid->embedding ivf into a centroid->passage ivf _, _ = optimize_ivf(ivf, ivf_lengths, self.config.index_path_) def _update_metadata(self): diff --git a/libs/colbert/colbert/indexing/index_manager.py b/libs/colbert/colbert/indexing/index_manager.py new file mode 100644 index 0000000..fd69bbd --- /dev/null +++ b/libs/colbert/colbert/indexing/index_manager.py @@ -0,0 +1,39 @@ +import numpy as np +import torch +from bitarray import bitarray + + +class IndexManager: + def __init__(self, dim): + self.dim = dim + + def save(self, tensor, path_prefix): + torch.save(tensor, path_prefix) + + def save_bitarray(self, arr, path_prefix): + with open(path_prefix, "wb") as f: + arr.tofile(f) + + +def load_index_part(filename, verbose=True): + part = torch.load(filename) + + if isinstance(part, list): # for backward compatibility + part = torch.cat(part) + + return part + + +def load_compressed_index_part(filename, dim, bits): + a = bitarray() + + with open(filename, "rb") as f: + a.fromfile(f) + + n = len(a) // dim // bits + part = torch.tensor( + np.frombuffer(a.tobytes(), dtype=np.uint8) + ) # TODO: isn't from_numpy(.) faster?
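+    # Each of the n vectors packs dim * bits bits, i.e. ceil(dim * bits / 8) bytes per row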
part = part.reshape((n, int(np.ceil(dim * bits / 8)))) + + return part diff --git a/libs/colbert/colbert/infra/config/core_config.py b/libs/colbert/colbert/infra/config/core_config.py index 4e8fc33..5ee64ee 100644 --- a/libs/colbert/colbert/infra/config/core_config.py +++ b/libs/colbert/colbert/infra/config/core_config.py @@ -14,6 +14,12 @@ class DefaultVal: val: Any + def __hash__(self): + return hash(repr(self.val)) + + def __eq__(self, other): + return self.val == other.val + @dataclass class CoreConfig: diff --git a/libs/colbert/colbert/infra/config/settings.py b/libs/colbert/colbert/infra/config/settings.py index a0e9172..7e5ebbe 100644 --- a/libs/colbert/colbert/infra/config/settings.py +++ b/libs/colbert/colbert/infra/config/settings.py @@ -150,7 +150,7 @@ class TrainingSettings: ignore_scores: bool = DefaultVal(False) - model_name: str = DefaultVal("bert-base-uncased") + model_name: str = DefaultVal(None) # DefaultVal('bert-base-uncased') @dataclass diff --git a/libs/colbert/colbert/modeling/base_colbert.py b/libs/colbert/colbert/modeling/base_colbert.py index 4cfde21..9f68a60 100644 --- a/libs/colbert/colbert/modeling/base_colbert.py +++ b/libs/colbert/colbert/modeling/base_colbert.py @@ -1,8 +1,9 @@ import os +import sys import torch from colbert.infra.config import ColBERTConfig -from colbert.modeling.hf_colbert import HF_ColBERT +from colbert.modeling.hf_colbert import class_factory from colbert.utils.utils import torch_load_dnn from transformers import AutoTokenizer @@ -15,15 +16,17 @@ class BaseColBERT(torch.nn.Module): Like HF, evaluation mode is the default. """ - def __init__(self, name, colbert_config=None): + def __init__(self, name_or_path, colbert_config=None): super().__init__() - self.name = name self.colbert_config = ColBERTConfig.from_existing( - ColBERTConfig.load_from_checkpoint(name), colbert_config + ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config ) - self.model = HF_ColBERT.from_pretrained(name, colbert_config=self.colbert_config) - self.raw_tokenizer = AutoTokenizer.from_pretrained(self.model.base) + self.name = self.colbert_config.model_name or name_or_path + assert self.name is not None + HF_ColBERT = class_factory(self.name) + self.model = HF_ColBERT.from_pretrained(name_or_path, colbert_config=self.colbert_config) + self.raw_tokenizer = AutoTokenizer.from_pretrained(name_or_path) self.eval() @@ -33,7 +36,7 @@ def device(self): @property def bert(self): - return self.model.bert + return self.model.LM @property def linear(self): diff --git a/libs/colbert/colbert/modeling/colbert.py b/libs/colbert/colbert/modeling/colbert.py index cbdadf2..f574da4 100644 --- a/libs/colbert/colbert/modeling/colbert.py +++ b/libs/colbert/colbert/modeling/colbert.py @@ -103,7 +103,6 @@ def doc(self, input_ids, attention_mask, keep_dims=True): input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device) D = self.bert(input_ids, attention_mask=attention_mask)[0] D = self.linear(D) - mask = ( torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device) .unsqueeze(2) .float() ) @@ -126,7 +125,6 @@ def doc(self, input_ids, attention_mask, keep_dims=True): def score(self, Q, D_padded, D_mask): # assert self.colbert_config.similarity == 'cosine' - if self.colbert_config.similarity == "l2": assert self.colbert_config.interaction == "colbert" return ( (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1)) ** 2).sum(-1)) .max(-1) .values.sum(-1) ) - return colbert_score(Q, D_padded, D_mask, config=self.colbert_config) def mask(self, input_ids, skiplist):
diff --git a/libs/colbert/colbert/modeling/hf_colbert.py b/libs/colbert/colbert/modeling/hf_colbert.py index 617bdaa..b99227f 100644 --- a/libs/colbert/colbert/modeling/hf_colbert.py +++ b/libs/colbert/colbert/modeling/hf_colbert.py @@ -1,71 +1,159 @@ import torch.nn as nn +import transformers from colbert.utils.utils import torch_load_dnn -from transformers import AutoTokenizer, BertModel, BertPreTrainedModel - - -class HF_ColBERT(BertPreTrainedModel): +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + BertModel, + BertPreTrainedModel, + DebertaV2Model, + DebertaV2PreTrainedModel, + ElectraModel, + ElectraPreTrainedModel, + RobertaModel, + RobertaPreTrainedModel, + XLMRobertaConfig, + XLMRobertaModel, +) + + +class XLMRobertaPreTrainedModel(RobertaPreTrainedModel): """ - Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level. - - This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly. + This class overrides [`RobertaPreTrainedModel`]. Please check the superclass for the appropriate documentation alongside + usage examples. """ - _keys_to_ignore_on_load_unexpected = [r"cls"] - - def __init__(self, config, colbert_config): - super().__init__(config) - - self.dim = colbert_config.dim - self.bert = BertModel(config) - self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False) - - # if colbert_config.relu: - # self.score_scaler = nn.Linear(1, 1) - - self.init_weights() - - # if colbert_config.relu: - # self.score_scaler.weight.data.fill_(1.0) - # self.score_scaler.bias.data.fill_(-8.0) - - @classmethod - def from_pretrained(cls, name_or_path, colbert_config): - if name_or_path.endswith(".dnn"): - dnn = torch_load_dnn(name_or_path) - base = dnn.get("arguments", {}).get("model", "bert-base-uncased") - - obj = super().from_pretrained( - base, state_dict=dnn["model_state_dict"], colbert_config=colbert_config - ) - obj.base = base + config_class = XLMRobertaConfig + + +base_class_mapping = { + "roberta-base": RobertaPreTrainedModel, + "google/electra-base-discriminator": ElectraPreTrainedModel, + "xlm-roberta-base": XLMRobertaPreTrainedModel, + "xlm-roberta-large": XLMRobertaPreTrainedModel, + "bert-base-uncased": BertPreTrainedModel, + "bert-large-uncased": BertPreTrainedModel, + "microsoft/mdeberta-v3-base": DebertaV2PreTrainedModel, + "bert-base-multilingual-uncased": BertPreTrainedModel, +} + +model_object_mapping = { + "roberta-base": RobertaModel, + "google/electra-base-discriminator": ElectraModel, + "xlm-roberta-base": XLMRobertaModel, + "xlm-roberta-large": XLMRobertaModel, + "bert-base-uncased": BertModel, + "bert-large-uncased": BertModel, + "microsoft/mdeberta-v3-base": DebertaV2Model, + "bert-base-multilingual-uncased": BertModel, +} + + +transformers_module = dir(transformers) + + +def find_class_names(model_type, class_type): + model_type = model_type.replace("-", "").lower() + for item in transformers_module: + if model_type + class_type == item.lower(): + return item + + return None + + +def class_factory(name_or_path): + loadedConfig = AutoConfig.from_pretrained(name_or_path) + model_type = loadedConfig.model_type + pretrained_class = find_class_names(model_type, "pretrainedmodel") + model_class = find_class_names(model_type, "model") + + if pretrained_class is not None: + pretrained_class_object = getattr(transformers, pretrained_class) + elif model_type == "xlm-roberta": + pretrained_class_object =
XLMRobertaPreTrainedModel + elif base_class_mapping.get(name_or_path) is not None: + pretrained_class_object = base_class_mapping.get(name_or_path) + else: + raise ValueError( + f"Could not find correct pretrained class for the model type {model_type} in transformers library" + ) + + if model_class is not None: + model_class_object = getattr(transformers, model_class) + elif model_object_mapping.get(name_or_path) is not None: + model_class_object = model_object_mapping.get(name_or_path) + else: + raise ValueError( + f"Could not find correct model class for the model type {model_type} in transformers library" + ) + + class HF_ColBERT(pretrained_class_object): + """ + Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level. + + This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly. + """ + + _keys_to_ignore_on_load_unexpected = [r"cls"] + + def __init__(self, config, colbert_config): + super().__init__(config) + + self.config = config + self.dim = colbert_config.dim + self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False) + setattr(self, self.base_model_prefix, model_class_object(config)) + + # if colbert_config.relu: + # self.score_scaler = nn.Linear(1, 1) + + self.init_weights() + + # if colbert_config.relu: + # self.score_scaler.weight.data.fill_(1.0) + # self.score_scaler.bias.data.fill_(-8.0) + + @property + def LM(self): + base_model_prefix = getattr(self, "base_model_prefix") + return getattr(self, base_model_prefix) + + @classmethod + def from_pretrained(cls, name_or_path, colbert_config): + if name_or_path.endswith(".dnn"): + dnn = torch_load_dnn(name_or_path) + base = dnn.get("arguments", {}).get("model", "bert-base-uncased") + + obj = super().from_pretrained( + base, state_dict=dnn["model_state_dict"], colbert_config=colbert_config + ) + obj.base = base + + return obj + + obj = super().from_pretrained(name_or_path, colbert_config=colbert_config) + obj.base = name_or_path return obj - obj = super().from_pretrained(name_or_path, colbert_config=colbert_config) - obj.base = name_or_path + @staticmethod + def raw_tokenizer_from_pretrained(name_or_path): + if name_or_path.endswith(".dnn"): + dnn = torch_load_dnn(name_or_path) + base = dnn.get("arguments", {}).get("model", "bert-base-uncased") - @staticmethod - def raw_tokenizer_from_pretrained(name_or_path): - if name_or_path.endswith(".dnn"): - dnn = torch_load_dnn(name_or_path) - base = dnn.get("arguments", {}).get("model", "bert-base-uncased") + obj = AutoTokenizer.from_pretrained(base) + obj.base = base - obj = AutoTokenizer.from_pretrained(base) - obj.base = base + return obj - return obj + obj = AutoTokenizer.from_pretrained(name_or_path) + obj.base = name_or_path return obj - obj = AutoTokenizer.from_pretrained(name_or_path) - obj.base = name_or_path - - return obj - - -""" -TODO: It's easy to write a class generator that takes "name_or_path" and loads AutoConfig to check the Architecture's - name, finds that name's *PreTrainedModel and *Model in dir(transformers), and then basically repeats the above. - - It's easy for the BaseColBERT class to instantiate things from there.
-""" + return HF_ColBERT diff --git a/libs/colbert/colbert/modeling/tokenization/doc_tokenization.py b/libs/colbert/colbert/modeling/tokenization/doc_tokenization.py index b851aff..0662da6 100644 --- a/libs/colbert/colbert/modeling/tokenization/doc_tokenization.py +++ b/libs/colbert/colbert/modeling/tokenization/doc_tokenization.py @@ -1,6 +1,6 @@ import torch from colbert.infra import ColBERTConfig -from colbert.modeling.hf_colbert import HF_ColBERT +from colbert.modeling.hf_colbert import class_factory from colbert.modeling.tokenization.utils import _sort_by_length, _split_into_batches # from transformers import BertTokenizerFast @@ -8,19 +8,19 @@ class DocTokenizer: def __init__(self, config: ColBERTConfig): + HF_ColBERT = class_factory(config.checkpoint) self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint) self.config = config self.doc_maxlen = config.doc_maxlen - self.D_marker_token, self.D_marker_token_id = "[D]", self.tok.convert_tokens_to_ids( - "[unused1]" - ) + ( + self.D_marker_token, + self.D_marker_token_id, + ) = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id) self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id - assert self.D_marker_token_id == 2 - def tokenize(self, batch_text, add_special_tokens=False): assert type(batch_text) in [list, tuple], type(batch_text) diff --git a/libs/colbert/colbert/modeling/tokenization/query_tokenization.py b/libs/colbert/colbert/modeling/tokenization/query_tokenization.py index ab15ba6..fcbc679 100644 --- a/libs/colbert/colbert/modeling/tokenization/query_tokenization.py +++ b/libs/colbert/colbert/modeling/tokenization/query_tokenization.py @@ -1,26 +1,27 @@ import torch from colbert.infra import ColBERTConfig -from colbert.modeling.hf_colbert import HF_ColBERT +from colbert.modeling.hf_colbert import class_factory from colbert.modeling.tokenization.utils import _split_into_batches from colbert.utils.utils import batch class QueryTokenizer: def __init__(self, config: ColBERTConfig): + HF_ColBERT = class_factory(config.checkpoint) self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint) self.config = config self.query_maxlen = config.query_maxlen self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable - self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.tok.convert_tokens_to_ids( - "[unused0]" - ) + ( + self.Q_marker_token, + self.Q_marker_token_id, + ) = config.query_token, self.tok.convert_tokens_to_ids(config.query_token_id) self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id - - assert self.Q_marker_token_id == 1 and self.mask_token_id == 103 + self.pad_token, self.pad_token_id = self.tok.pad_token, self.tok.pad_token_id self.used = False def tokenize(self, batch_text, add_special_tokens=False): @@ -73,7 +74,7 @@ def tensorize(self, batch_text, bsize=None, context=None): # postprocess for the [Q] marker and the [MASK] augmentation ids[:, 1] = self.Q_marker_token_id - ids[ids == 0] = self.mask_token_id + ids[ids == self.pad_token_id] = self.mask_token_id if context is not None: assert len(context) == len(batch_text), (len(context), len(batch_text)) diff --git a/libs/colbert/colbert/search/index_storage.py b/libs/colbert/colbert/search/index_storage.py index 
43b84da..a0cf344 100644 --- a/libs/colbert/colbert/search/index_storage.py +++ b/libs/colbert/colbert/search/index_storage.py @@ -81,6 +81,8 @@ def rank(self, config, Q, filter_fn=None): if filter_fn is not None: pids = filter_fn(pids) + if len(pids) == 0: + return [], [] scores, pids = self.score_pids(config, Q, pids, centroid_scores) diff --git a/libs/colbert/colbert/searcher.py b/libs/colbert/colbert/searcher.py index 162bd85..927ea6e 100644 --- a/libs/colbert/colbert/searcher.py +++ b/libs/colbert/colbert/searcher.py @@ -31,8 +31,8 @@ def __init__(self, index, checkpoint=None, collection=None, config=None): self.checkpoint_config, self.index_config, initial_config ) - # self.collection = Collection.cast(collection or self.config.collection) - self.configure(checkpoint=self.checkpoint) + self.collection = Collection.cast(collection or self.config.collection) + self.configure(checkpoint=self.checkpoint, collection=self.collection) self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config) use_gpu = self.config.total_visible_gpus > 0 diff --git a/libs/colbert/conda_env.yml b/libs/colbert/conda_env.yml index 294bd29..1b9184c 100644 --- a/libs/colbert/conda_env.yml +++ b/libs/colbert/conda_env.yml @@ -26,3 +26,5 @@ dependencies: - tqdm - transformers - ujson + - flask + - python-dotenv diff --git a/libs/colbert/conda_env_cpu.yml b/libs/colbert/conda_env_cpu.yml index 1e4702f..28002ed 100644 --- a/libs/colbert/conda_env_cpu.yml +++ b/libs/colbert/conda_env_cpu.yml @@ -18,3 +18,5 @@ dependencies: - tqdm - transformers - ujson + - flask + - python-dotenv diff --git a/libs/colbert/server.py b/libs/colbert/server.py new file mode 100644 index 0000000..6d2c953 --- /dev/null +++ b/libs/colbert/server.py @@ -0,0 +1,51 @@ +import math +import os +from functools import lru_cache + +from colbert import Searcher +from colbert.infra import ColBERTConfig, Run, RunConfig +from dotenv import load_dotenv +from flask import Flask, render_template, request + +load_dotenv() + +INDEX_NAME = os.getenv("INDEX_NAME") +INDEX_ROOT = os.getenv("INDEX_ROOT") +app = Flask(__name__) + +searcher = Searcher(index=f"{INDEX_ROOT}/{INDEX_NAME}") +counter = {"api": 0} + + +@lru_cache(maxsize=1000000) +def api_search_query(query, k): + print(f"Query={query}") + if k is None: + k = 10 + k = min(int(k), 100) + pids, ranks, scores = searcher.search(query, k=100) + pids, ranks, scores = pids[:k], ranks[:k], scores[:k] + probs = [math.exp(score) for score in scores] + total = sum(probs) + probs = [prob / total for prob in probs] + topk = [] + for pid, rank, score, prob in zip(pids, ranks, scores, probs): + text = searcher.collection[pid] + d = {"text": text, "pid": pid, "rank": rank, "score": score, "prob": prob} + topk.append(d) + topk = sorted(topk, key=lambda p: (-1 * p["score"], p["pid"])) + return {"query": query, "topk": topk} + + +@app.route("/api/search", methods=["GET"]) +def api_search(): + if request.method == "GET": + counter["api"] += 1 + print("API request count:", counter["api"]) + return api_search_query(request.args.get("query"), request.args.get("k")) + else: + return ("", 405) + + +if __name__ == "__main__": + app.run("0.0.0.0", int(os.getenv("PORT", 8893))) # default port matches the sample query in the README diff --git a/models.md b/models.md index 527120d..44f2702 100644 --- a/models.md +++ b/models.md @@ -42,9 +42,23 @@ Hub](https://huggingface.co/Intel/ColBERT-NQ) for more details. of all the documents. An index can be created by the user given a collection and a checkpoint, or can be specified via a path.
-> :warning: PLAID Requirements :warning: +**Updated:** a new feature enables adding and removing documents from a given index. Example usage: + +```python +index_updater = IndexUpdater(config, searcher, checkpoint) + +added_pids = index_updater.add(passages) # Adding passages +index_updater.remove(pids) # Removing passages +searcher.search() # Search now reflects the added & removed passages + +index_updater.persist_to_disk() # Persist changes to disk +``` + +--- + +#### :warning: PLAID Requirements :warning: > -> If GPU is needed it should be of type RTX 3090 or newer and PyTorch should be installed with CUDA support using: +> If a GPU is needed, it should be an RTX 3090 or newer (Ampere) and PyTorch should be installed with CUDA support using: > >```bash >pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116 >```
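+>
+>As a quick sanity check (an illustrative one-liner, not part of the repo), you can verify that the CUDA build of PyTorch is active:
+>```bash
+>python -c "import torch; print(torch.cuda.is_available())"
+>```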