import os

import torch
from FlagEmbedding import BGEM3FlagModel


def pad_colbert_vecs(colbert_vecs_list, device):
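    """Pad variable-length ColBERT token-vector arrays into a single
    (num_texts, max_tokens, dim) float tensor.

    Padding rows are all zeros; a zero row dots to 0 with every query token,
    which leaves the MaxSim result unchanged as long as the true maximum
    similarity is non-negative.
    """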
    lengths = [vec.shape[0] for vec in colbert_vecs_list]
    max_len = max(lengths)
    dim = colbert_vecs_list[0].shape[1]

    padded_tensor = torch.zeros(len(colbert_vecs_list), max_len, dim, dtype=torch.float, device=device)
    for i, vec in enumerate(colbert_vecs_list):
        length = vec.shape[0]
        padded_tensor[i, :length, :] = torch.tensor(vec, dtype=torch.float, device=device)

    return padded_tensor


def compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs):
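    """Late-interaction (MaxSim) scoring.

    For every query token, take the maximum dot product over all passage
    tokens, then sum those maxima over the query tokens:
        score(q, p) = sum_r max_c <q_r, p_c>
    """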
    # query_colbert_vecs:   (Q, Tq, D)
    # passage_colbert_vecs: (P, Tp, D)
    # einsum indices: q = queries, p = passages, r = query tokens, c = passage tokens, d = embedding dim
    dot_products = torch.einsum("qrd,pcd->qprc", query_colbert_vecs, passage_colbert_vecs)  # (Q, P, Tq, Tp)
    max_per_query_token, _ = dot_products.max(dim=3)  # max over c (Tp)
    colbert_scores = max_per_query_token.sum(dim=2)   # sum over r (Tq)
    return colbert_scores


def hybrid_dbsf_ensemble(dense_scores, sparse_scores, colbert_scores, weights=(0.33, 0.33, 0.34)):
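    # A minimal weighted-sum fusion sketch. Full distribution-based score
    # fusion (DBSF) would first normalize each score distribution (e.g. via
    # min-max scaling) before combining; this version assumes the three score
    # ranges are already comparable.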
    w_dense, w_sparse, w_colbert = weights
    # Works as long as all three inputs are torch.Tensors of the same shape.
    return w_dense * dense_scores + w_sparse * sparse_scores + w_colbert * colbert_scores


def test_m3_single_device():
    # Fall back to CPU when no GPU is present so the model and every score
    # tensor share a single device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = BGEM3FlagModel(
        'BAAI/bge-m3',
        devices=str(device),
        pooling_method='cls',
        cache_dir=os.getenv('HF_HUB_CACHE', None),
    )

    queries = [
        "What is BGE M3?",
        "Defination of BM25"  # typo kept: the expected scores below were produced with this exact string
    ] * 100
    passages = [
        "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
        "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
    ] * 100

    queries_embeddings = model.encode_queries(
        queries,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=True,
    )
    passages_embeddings = model.encode_corpus(
        passages,
        return_dense=True,
        return_sparse=True,
        return_colbert_vecs=True,
    )

    # dense_vecs are returned as numpy arrays, so convert them to tensors.
    q_dense = torch.tensor(queries_embeddings["dense_vecs"], dtype=torch.float, device=device)
    p_dense = torch.tensor(passages_embeddings["dense_vecs"], dtype=torch.float, device=device)
    dense_scores = q_dense @ p_dense.T
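    # BGE-M3 dense embeddings are L2-normalized by default, so the inner
    # product above equals the cosine similarity of each query-passage pair.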

    # compute_lexical_matching_score returns a numpy array as well; convert it
    # to a tensor so it can be fused with the other scores.
    sparse_scores_np = model.compute_lexical_matching_score(
        queries_embeddings["lexical_weights"],
        passages_embeddings["lexical_weights"]
    )
    sparse_scores = torch.tensor(sparse_scores_np, dtype=torch.float, device=device)

    # Pad the variable-length colbert_vecs and stack them into batch tensors.
    query_colbert_vecs = pad_colbert_vecs(queries_embeddings["colbert_vecs"], device)
    passage_colbert_vecs = pad_colbert_vecs(passages_embeddings["colbert_vecs"], device)

    colbert_scores = compute_colbert_scores(query_colbert_vecs, passage_colbert_vecs)
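    # Note: FlagEmbedding's built-in colbert_score additionally divides the
    # MaxSim sum by the number of query tokens, so absolute values here will
    # be larger by roughly that factor.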

    # All three score matrices are (Q, P) torch.Tensors, so they can be fused
    # elementwise without further conversion.
    hybrid_scores = hybrid_dbsf_ensemble(dense_scores, sparse_scores, colbert_scores)

    print("Dense score:\n", dense_scores[:2, :2])
    print("Sparse score:\n", sparse_scores[:2, :2])
    print("ColBERT score:\n", colbert_scores[:2, :2])
    print("Hybrid DBSF Ensemble score:\n", hybrid_scores[:2, :2])


if __name__ == '__main__':
    test_m3_single_device()
    print("--------------------------------")
    print("Expected Output for Dense & Sparse (original):")
    print("Dense score:")
    print(" [[0.626 0.3477]\n [0.3496 0.678 ]]")
    print("Sparse score:")
    print(" [[0.19554901 0.00880432]\n [0. 0.18036556]]")
    print("--------------------------------")
    print("ColBERT and Hybrid DBSF scores will vary depending on the actual embeddings.")