Skip to content

Commit 184b7a7

Browse files
authored
Merge pull request #61 from fschlatt/main
Fix CI + Minor pretty printing fix
2 parents 4c7e09b + 8730bfd commit 184b7a7

File tree

3 files changed

+15
-16
lines changed

3 files changed

+15
-16
lines changed

lightning_ir/base/module.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from lightning.pytorch.trainer.states import RunningStage
1515
from transformers import BatchEncoding
1616

17-
from ..data import LightningIRDataModule, RankBatch, RunDataset, SearchBatch, TrainBatch
17+
from ..data import RankBatch, RunDataset, SearchBatch, TrainBatch
1818
from ..data.dataset import IRDataset
1919
from ..loss.loss import InBatchLossFunction, LossFunction
2020
from .config import LightningIRConfig
@@ -414,10 +414,9 @@ def on_validation_end(self) -> None:
414414
df = df.pivot(index="dataset", columns="metric", values="value")
415415
df.columns.name = None
416416

417-
datamodule: LightningIRDataModule | None = getattr(trainer, "datamodule", None)
418-
if datamodule is not None and datamodule.inference_datasets is not None:
419-
dataset_ids = [dataset.dataset_id for dataset in datamodule.inference_datasets]
420-
df = df.reindex(dataset_ids)
417+
# bring into correct order when skipping inference datasets
418+
dataset_ids = [self.get_dataset_id(i) for i in range(df.shape[0])]
419+
df = df.reindex(dataset_ids)
421420

422421
trainer.print(df)
423422

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ test =
4242
sentence-transformers
4343
faiss-cpu
4444
pyseismic-lsr
45+
pylate
4546
dev =
4647
black
4748
flake8

tests/test_models/test_col.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from colbert.modeling.checkpoint import Checkpoint # noqa: E402
77
from colbert.modeling.colbert import ColBERTConfig, colbert_score # noqa: E402
8-
98
from pylate import models, rank
109

1110
from lightning_ir import BiEncoderModule # noqa: E402
@@ -55,16 +54,16 @@ def test_same_as_modern_colbert():
5554
doc_embedding = output.doc_embeddings
5655

5756
orig_model = models.ColBERT(model_name_or_path=model_name)
58-
orig_query = orig_model.encode(
59-
[query],
60-
is_query=True,
61-
)
62-
orig_docs = orig_model.encode(
63-
[documents],
64-
is_query=False,
57+
orig_query = orig_model.encode([query], is_query=True)
58+
orig_docs = orig_model.encode([documents], is_query=False)
59+
orig_scores = rank.rerank(
60+
queries_embeddings=orig_query, documents_embeddings=orig_docs, documents_ids=[list(range(len(documents)))]
6561
)
66-
orig_scores = rank.rerank(queries_embeddings=orig_query, documents_embeddings=orig_docs, documents_ids=[list(range(len(documents)))])
6762

6863
assert torch.allclose(query_embedding.embeddings, torch.tensor(orig_query[0]), atol=1e-6)
69-
assert torch.allclose(doc_embedding.embeddings[doc_embedding.scoring_mask], torch.cat([torch.from_numpy(d) for doc in orig_docs for d in doc]), atol=1e-6)
70-
assert torch.allclose(output.scores, torch.tensor([d["score"] for q in orig_scores for d in q]), atol=1e-6)
64+
assert torch.allclose(
65+
doc_embedding.embeddings[doc_embedding.scoring_mask],
66+
torch.cat([torch.from_numpy(d) for doc in orig_docs for d in doc]),
67+
atol=1e-6,
68+
)
69+
assert torch.allclose(output.scores, torch.tensor([d["score"] for q in orig_scores for d in q]), atol=1e-6)

0 commit comments

Comments (0)