Skip to content

Commit 630223c

Browse files
committed
chore: example for ensemble
1 parent bc09ef5 commit 630223c

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,105 @@
1+
import os
2+
import torch
3+
import numpy as np
4+
from FlagEmbedding import BGEM3FlagModel
5+
6+
7+
def pad_colbert_vecs(colbert_vecs_list, device):
    """Stack variable-length ColBERT token-embedding matrices into one padded tensor.

    Args:
        colbert_vecs_list: sequence of per-text arrays/tensors, each of shape
            (num_tokens_i, dim); num_tokens_i may differ between texts.
        device: torch device on which the padded tensor is allocated.

    Returns:
        Float tensor of shape (len(colbert_vecs_list), max_num_tokens, dim),
        zero-padded beyond each text's real token count.

    Raises:
        ValueError: if ``colbert_vecs_list`` is empty (previously this
            surfaced as an opaque ``max()``/IndexError).
    """
    if not colbert_vecs_list:
        raise ValueError("colbert_vecs_list must not be empty")

    lengths = [vec.shape[0] for vec in colbert_vecs_list]
    max_len = max(lengths)
    dim = colbert_vecs_list[0].shape[1]

    padded_tensor = torch.zeros(
        len(colbert_vecs_list), max_len, dim, dtype=torch.float, device=device
    )
    for i, (vec, length) in enumerate(zip(colbert_vecs_list, lengths)):
        # as_tensor avoids the extra copy (and the UserWarning) that
        # torch.tensor() performs when vec is already a torch.Tensor.
        padded_tensor[i, :length, :] = torch.as_tensor(
            vec, dtype=torch.float, device=device
        )
    return padded_tensor
18+
19+
20+
def compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs):
    """Late-interaction (MaxSim) scores between every query/passage pair.

    Args:
        query_colbert_vecs: (Q, Tq, D) padded query token embeddings.
        passage_colbert_vecs: (P, Tp, D) padded passage token embeddings.

    Returns:
        (Q, P) tensor where entry [q, p] is the sum, over query tokens, of
        the maximum dot product against any passage token.

    NOTE(review): padded (all-zero) passage tokens yield dot products of 0,
    which caps negative similarities at 0 — confirm this is acceptable or
    mask padding upstream.
    """
    # Einsum axes: q=query index, p=passage index, r=query-token axis,
    # c=passage-token axis, d=embedding dimension.
    token_sims = torch.einsum(
        "qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs
    )  # (Q, P, Tq, Tp)
    # MaxSim: best passage token per query token, then sum over query tokens.
    return token_sims.max(dim=-1).values.sum(dim=-1)
28+
29+
30+
def hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores, weights=(0.33, 0.33, 0.34)):
    """Fuse dense, sparse and ColBERT score matrices into one hybrid score.

    All three score arguments must be torch tensors of the same shape
    (typically (Q, P)); the result is their weighted sum.

    Args:
        dense_scores: dense-retrieval score tensor.
        sparse_scores: lexical-matching score tensor.
        colbert_scores: multi-vector (ColBERT) score tensor.
        weights: 3-tuple of (dense, sparse, colbert) weights.

    Returns:
        Tensor of the same shape as the inputs containing the weighted sum.
    """
    components = (dense_scores, sparse_scores, colbert_scores)
    fused = 0
    for weight, score in zip(weights, components):
        fused = fused + weight * score
    return fused
34+
35+
36+
def test_m3_single_device():
    """End-to-end example: encode with BGE-M3, then fuse dense/sparse/ColBERT scores."""
    model = BGEM3FlagModel(
        'BAAI/bge-m3',
        devices="cuda:0",
        pooling_method='cls',
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )

    queries = [
        "What is BGE M3?",
        "Defination of BM25"
    ] * 100
    passages = [
        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
    ] * 100

    # Request all three representations from both encoders.
    encode_kwargs = dict(return_dense=True, return_sparse=True, return_colbert_vecs=True)
    q_emb = model.encode_queries(queries, **encode_kwargs)
    p_emb = model.encode_corpus(passages, **encode_kwargs)

    # Device used for the score computations below.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # dense_vecs (and lexical weights) may come back as numpy arrays,
    # so convert to tensors before doing tensor math.
    q_dense = torch.tensor(q_emb["dense_vecs"], dtype=torch.float, device=device)
    p_dense = torch.tensor(p_emb["dense_vecs"], dtype=torch.float, device=device)
    dense_scores = q_dense @ p_dense.T

    # Lexical-matching scores are numpy as well; convert them too.
    sparse_scores = torch.tensor(
        model.compute_lexical_matching_score(
            q_emb["lexical_weights"],
            p_emb["lexical_weights"]
        ),
        dtype=torch.float,
        device=device,
    )

    # Pad the variable-length ColBERT vectors, then score with MaxSim.
    q_colbert = pad_colbert_vecs(q_emb["colbert_vecs"], device)
    p_colbert = pad_colbert_vecs(p_emb["colbert_vecs"], device)
    colbert_scores = compute_colbert_scores(q_colbert, p_colbert)

    # All score matrices are torch tensors, so the weighted fusion is safe.
    hybrid_scores = hybrid_dbfs_ensemble(dense_scores, sparse_scores, colbert_scores)

    print("Dense score:\n", dense_scores[:2, :2])
    print("Sparse score:\n", sparse_scores[:2, :2])
    print("ColBERT score:\n", colbert_scores[:2, :2])
    print("Hybrid DBSF Ensemble score:\n", hybrid_scores[:2, :2])
94+
95+
96+
if __name__ == '__main__':
    test_m3_single_device()
    # Reference output for the deterministic (dense & sparse) scores.
    for line in (
        "--------------------------------",
        "Expected Output for Dense & Sparse (original):",
        "Dense score:",
        " [[0.626 0.3477]\n [0.3496 0.678 ]]",
        "Sparse score:",
        " [[0.19554901 0.00880432]\n [0. 0.18036556]]",
        "--------------------------------",
        "ColBERT and Hybrid DBSF scores will vary depending on the actual embeddings.",
    ):
        print(line)

0 commit comments

Comments
 (0)