 from colbert.modeling.checkpoint import Checkpoint  # noqa: E402
 from colbert.modeling.colbert import ColBERTConfig, colbert_score  # noqa: E402
-
 from pylate import models, rank

 from lightning_ir import BiEncoderModule  # noqa: E402
@@ -55,16 +54,16 @@ def test_same_as_modern_colbert():
     doc_embedding = output.doc_embeddings

     orig_model = models.ColBERT(model_name_or_path=model_name)
-    orig_query = orig_model.encode(
-        [query],
-        is_query=True,
-    )
-    orig_docs = orig_model.encode(
-        [documents],
-        is_query=False,
-    )
+    orig_query = orig_model.encode([query], is_query=True)
+    orig_docs = orig_model.encode([documents], is_query=False)
+    orig_scores = rank.rerank(
+        queries_embeddings=orig_query, documents_embeddings=orig_docs, documents_ids=[list(range(len(documents)))]
     )
-    orig_scores = rank.rerank(queries_embeddings=orig_query, documents_embeddings=orig_docs, documents_ids=[list(range(len(documents)))])

     assert torch.allclose(query_embedding.embeddings, torch.tensor(orig_query[0]), atol=1e-6)
-    assert torch.allclose(doc_embedding.embeddings[doc_embedding.scoring_mask], torch.cat([torch.from_numpy(d) for doc in orig_docs for d in doc]), atol=1e-6)
-    assert torch.allclose(output.scores, torch.tensor([d["score"] for q in orig_scores for d in q]), atol=1e-6)
+    assert torch.allclose(
+        doc_embedding.embeddings[doc_embedding.scoring_mask],
+        torch.cat([torch.from_numpy(d) for doc in orig_docs for d in doc]),
+        atol=1e-6,
+    )
+    assert torch.allclose(output.scores, torch.tensor([d["score"] for q in orig_scores for d in q]), atol=1e-6)
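For reference, here is a minimal standalone sketch of the PyLate path that the test compares against. The checkpoint name and the query/documents below are illustrative placeholders (the test's model_name, query, and documents fixtures are defined above this hunk); the PyLate calls themselves mirror the ones in the new test code, and the final line flattens the reranker output the same way the last assertion does.

# Minimal sketch of the PyLate reference path, under stated assumptions:
# "colbert-ir/colbertv2.0" stands in for the test's model_name, and
# query/documents are toy inputs rather than the test fixtures.
import torch
from pylate import models, rank

model_name = "colbert-ir/colbertv2.0"
query = "What is ColBERT?"
documents = ["ColBERT is a late-interaction retrieval model.", "Paris is the capital of France."]

orig_model = models.ColBERT(model_name_or_path=model_name)
orig_query = orig_model.encode([query], is_query=True)      # per-query token embeddings
orig_docs = orig_model.encode([documents], is_query=False)  # nested: one list of document embeddings per query
orig_scores = rank.rerank(
    queries_embeddings=orig_query,
    documents_embeddings=orig_docs,
    documents_ids=[list(range(len(documents)))],
)

# Flatten to one score per (query, document) pair, as the final allclose check does.
flat_scores = torch.tensor([d["score"] for q in orig_scores for d in q])
print(flat_scores)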