Skip to content

Commit e2b32ba

Browse files
author
Morten Terhart
committed
Separate dim32 and dim512 embedding results into different folders and update paths
1 parent 15b72b6 commit e2b32ba

23 files changed

+189038
-190098
lines changed

complete_triple_patterns.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ def main():
1313
wikidata5m_valid = load_wikidata5m_dataset('valid')
1414
wikidata5m_test = load_wikidata5m_dataset('test')
1515

16-
model = torch.load('embeddings/ComplEx/trained_model.pkl')
17-
model_factory = TriplesFactory.from_path_binary('embeddings/ComplEx/training_triples')
16+
model = torch.load('embeddings/dim_32/complex/trained_model.pkl')
17+
model_factory = TriplesFactory.from_path_binary('embeddings/dim_32/complex/training_triples')
1818
print('Training set stats:')
1919
print(f' Num Triples: {model_factory.num_triples}')
2020
print(f' Num Entities: {model_factory.num_entities}')
@@ -32,11 +32,11 @@ def main():
3232
print(preds_df[preds_df['in_testing'] == True])
3333

3434
predicate_label = wikidata5m_test['P'].iloc[10]
35-
arithmetic_mean_rank = get_predicate_metric(predicate_metrics, 'arithmetic_mean_rank', predicate_label, 'ComplEx', 'tail', 'realistic')
36-
hits_at_1 = get_predicate_metric(predicate_metrics, 'hits_at_1', predicate_label, 'ComplEx', 'tail', 'realistic')
37-
hits_at_3 = get_predicate_metric(predicate_metrics, 'hits_at_3', predicate_label, 'ComplEx', 'tail', 'realistic')
38-
hits_at_5 = get_predicate_metric(predicate_metrics, 'hits_at_5', predicate_label, 'ComplEx', 'tail', 'realistic')
39-
hits_at_10 = get_predicate_metric(predicate_metrics, 'hits_at_10', predicate_label, 'ComplEx', 'tail', 'realistic')
35+
arithmetic_mean_rank = get_predicate_metric(predicate_metrics, 'arithmetic_mean_rank', predicate_label, 'complex', 'tail', 'realistic')
36+
hits_at_1 = get_predicate_metric(predicate_metrics, 'hits_at_1', predicate_label, 'complex', 'tail', 'realistic')
37+
hits_at_3 = get_predicate_metric(predicate_metrics, 'hits_at_3', predicate_label, 'complex', 'tail', 'realistic')
38+
hits_at_5 = get_predicate_metric(predicate_metrics, 'hits_at_5', predicate_label, 'complex', 'tail', 'realistic')
39+
hits_at_10 = get_predicate_metric(predicate_metrics, 'hits_at_10', predicate_label, 'complex', 'tail', 'realistic')
4040
print(f'Arithmetic mean rank: {arithmetic_mean_rank}')
4141
print(f'Hits at 1: {hits_at_1}')
4242
print(f'Hits at 3: {hits_at_3}')
@@ -45,7 +45,7 @@ def main():
4545

4646

4747
def load_wikidata5m_dataset(subset_type: SubsetType):
48-
return pd.read_csv(f'dataset/knowledge_graph/wikidata5m_transductive_{subset_type}.txt', sep='\t',
48+
return pd.read_csv(f'dataset/wikidata5m/wikidata5m_transductive_{subset_type}.txt', sep='\t',
4949
names=['S', 'P', 'O'])
5050

5151

compute_model_predictions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def main():
1515
torch.cuda.empty_cache()
1616

1717
print(f'[X] Loading {model_name} model')
18-
train_factory = TriplesFactory.from_path_binary(f'embeddings/{model_name}/training_factory')
18+
train_factory = TriplesFactory.from_path_binary(f'embeddings/dim_512/{model_name}/training_factory')
1919

2020
model = TransE(
2121
triples_factory=train_factory,
2222
embedding_dim=512
2323
)
24-
model.load_state_dict(torch.load(f'embeddings/{model_name}/trained_model_state_dict.pt'))
24+
model.load_state_dict(torch.load(f'embeddings/dim_512/{model_name}/trained_model_state_dict.pt'))
2525
model.to(device).eval()
2626

2727
print(f'[X] Loading Wikidata5M datasets')
@@ -39,7 +39,7 @@ def main():
3939
scores_df = pack.process(factory=train_factory).df
4040

4141
print(f'[X] Saving predicted scores')
42-
scores_df.to_csv(f'embeddings/{model_name}/predicted_scores.csv')
42+
scores_df.to_csv(f'embeddings/dim_512/{model_name}/predicted_scores.csv', index=False)
4343

4444

4545
if __name__ == '__main__':

compute_predicate_metrics.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def main():
15-
wikidata5m_test_set = pd.read_csv('dataset/knowledge_graph/wikidata5m_transductive_test.txt', sep="\t",
15+
wikidata5m_test_set = pd.read_csv('dataset/wikidata5m/wikidata5m_transductive_test.txt', sep="\t",
1616
names=['S', 'P', 'O'], header=None)
1717
trained_models = get_trained_models()
1818
print(
@@ -32,7 +32,7 @@ def main():
3232

3333
print(f'[X] Finished evaluation in {timedelta(seconds=timer() - start)}')
3434

35-
predicate_metrics.to_csv('metrics/predicate_metrics.csv')
35+
predicate_metrics.to_csv('metrics/predicate_metrics.csv', index=False)
3636

3737

3838
def get_number_of_predicates(dataset_df):
@@ -46,10 +46,10 @@ def get_test_set_per_predicate(test_set_file):
4646

4747
def get_trained_models():
4848
return {
49-
'ComplEx': _load_trained_model('embeddings/ComplEx'),
50-
'DistMult': _load_trained_model('embeddings/DistMult'),
51-
'SimplE': _load_trained_model('embeddings/SimplE'),
52-
'TransE': _load_trained_model('embeddings/TransE')
49+
'complex': _load_trained_model('embeddings/dim_32/complex'),
50+
'distmult': _load_trained_model('embeddings/dim_32/distmult'),
51+
'simple': _load_trained_model('embeddings/dim_32/simple'),
52+
'transe': _load_trained_model('embeddings/dim_32/transe')
5353
}
5454

5555

dataset/convert_csv_to_turtle.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55

66
def main():
7-
df = pd.read_csv('./knowledge_graph/wikidata5m_transductive_train.txt', sep='\t', names=['S', 'P', 'O'])
7+
df = pd.read_csv('./wikidata5m/wikidata5m_transductive_train.txt', sep='\t', names=['S', 'P', 'O'])
88

99
# Transform triples to Turtle format in the dataframe
1010
turtle_df = df.apply(row_to_turtle, axis=1)
1111

12-
turtle_file = './knowledge_graph/wikidata5m_transductive_train.ttl'
12+
turtle_file = './wikidata5m/wikidata5m_transductive_train.ttl'
1313
with open(turtle_file, 'w') as f:
1414
f.write(f'@prefix wd: <{wikidata_prefix}>\n\n')
1515

embeddings/ComplEx/metadata.json

-1
This file was deleted.

embeddings/ComplEx/results.json

-264
This file was deleted.

embeddings/DistMult/metadata.json

-1
This file was deleted.

0 commit comments

Comments
 (0)