kg_env.py
from graph_embeddings.models.factory import EmbeddingModelFactory
from graph_embeddings.models.embedding_model import EmbeddingModel
from graph_embeddings.data_loader import DataLoader
from typing import Dict
import torch
import numpy as np
import os
import re


def generate_entity_embeddings(kge_model: EmbeddingModel, entities_dict: Dict, batch_size=5000):
    """Look up the trained embedding vector for every entity index, in batches."""
    embeddings = {}
    entity_idx_list = list(entities_dict.values())
    kge_model.eval()
    i = 0
    while i < len(entity_idx_list):
        batch = entity_idx_list[i:i + batch_size]
        with torch.no_grad():
            # E is the model's entity embedding layer; index it with the batch of entity ids.
            emb_list = kge_model.E(torch.Tensor(batch).long()).cpu().numpy()
        for idx, emb in zip(batch, emb_list):
            embeddings[idx] = emb
        i += batch_size
    return embeddings


def generate_relation_embeddings(kge_model: EmbeddingModel, relations_dict: Dict):
    """Look up the trained embedding vector for every relation index in a single pass."""
    embeddings = {}
    rel_idx_list = list(relations_dict.values())
    kge_model.eval()
    with torch.no_grad():
        emb_list = kge_model.E(torch.Tensor(rel_idx_list).long()).cpu().numpy()
    for idx, emb in zip(rel_idx_list, emb_list):
        embeddings[idx] = emb
    return embeddings


def load_kge_model(dataset_name, model_name, ent_vec_dim, rel_vec_dim, loss_type, device, path, do_batch_norm=True,
                   reverse_rel=True, **kwargs):
    """Rebuild a KGE model via the factory and restore its weights from a saved checkpoint."""
    # Fix random seeds so the model is reconstructed deterministically.
    torch.backends.cudnn.deterministic = True
    seed = 20
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        print("operating on gpu")
        torch.cuda.manual_seed_all(seed)
    else:
        print("operating on cpu")
    data_loader = DataLoader(dataset=dataset_name, reverse_rel=reverse_rel)
    embedding_generator = EmbeddingModelFactory(model_name).create(
        data_loader, ent_vec_dim, rel_vec_dim, loss_type, device, do_batch_norm, **kwargs
    )
    # The checkpoint path is resolved relative to the dataset's base data directory.
    checkpoint = torch.load(os.path.join(data_loader.base_data_dir, path), map_location=torch.device(device))
    embedding_generator.load_state_dict(checkpoint)
    return embedding_generator


def extract_question_entity_target(raw_questions):
    """Parse raw question lines into questions, topic entities and answer lists.

    Each raw line holds a question and its '|'-separated answers, separated by a
    tab; the topic entity is marked with square brackets inside the question.
    """
    all_questions = []
    all_entities = []
    all_targets = []
    for raw_q in raw_questions:
        question, targets = raw_q.split('\t')
        print(question, targets)
        # The bracketed span marks the question's topic entity.
        entity = re.findall(r'\[.*?\]', question)[0] \
            .replace('[', '') \
            .replace(']', '')
        # todo: should I replace entity with some special token?
        question = question.replace(']', '').replace('[', '')
        targets = targets.strip().split('|')
        all_questions.append(question)
        all_targets.append(targets)
        all_entities.append(entity)
    return all_questions, all_entities, all_targets
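

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module logic). The
# dataset name, KGE model name, loss type, checkpoint path, embedding
# dimensions and the toy entity/relation index dictionaries below are
# placeholder assumptions; in practice the index dictionaries come from the
# dataset vocabularies that the checkpoint was trained with.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    kge_model = load_kge_model(
        dataset_name="MetaQA",              # placeholder dataset name
        model_name="ComplEx",               # placeholder model name
        ent_vec_dim=200,                    # placeholder dimensions
        rel_vec_dim=200,
        loss_type="BCE",                    # placeholder loss type
        device=device,
        path="checkpoints/best_model.pt",   # placeholder checkpoint path
    )
    # Toy index dictionaries mapping labels to embedding-table row indices.
    entities_dict = {"entity_a": 0, "entity_b": 1}
    relations_dict = {"relation_a": 0}
    entity_embeddings = generate_entity_embeddings(kge_model, entities_dict)
    relation_embeddings = generate_relation_embeddings(kge_model, relations_dict)
    # Raw question lines are tab-separated: question (with [entity]) and answers.
    raw = ["which films did [entity_a] appear in\tentity_b"]
    questions, entities, targets = extract_question_entity_target(raw)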