@@ -13,8 +13,8 @@ def main():
13
13
wikidata5m_valid = load_wikidata5m_dataset ('valid' )
14
14
wikidata5m_test = load_wikidata5m_dataset ('test' )
15
15
16
- model = torch .load ('embeddings/ComplEx /trained_model.pkl' )
17
- model_factory = TriplesFactory .from_path_binary ('embeddings/ComplEx /training_triples' )
16
+ model = torch .load ('embeddings/dim_32/complex /trained_model.pkl' )
17
+ model_factory = TriplesFactory .from_path_binary ('embeddings/dim_32/complex /training_triples' )
18
18
print ('Training set stats:' )
19
19
print (f' Num Triples: { model_factory .num_triples } ' )
20
20
print (f' Num Entities: { model_factory .num_entities } ' )
@@ -32,11 +32,11 @@ def main():
32
32
print (preds_df [preds_df ['in_testing' ] == True ])
33
33
34
34
predicate_label = wikidata5m_test ['P' ].iloc [10 ]
35
- arithmetic_mean_rank = get_predicate_metric (predicate_metrics , 'arithmetic_mean_rank' , predicate_label , 'ComplEx ' , 'tail' , 'realistic' )
36
- hits_at_1 = get_predicate_metric (predicate_metrics , 'hits_at_1' , predicate_label , 'ComplEx ' , 'tail' , 'realistic' )
37
- hits_at_3 = get_predicate_metric (predicate_metrics , 'hits_at_3' , predicate_label , 'ComplEx ' , 'tail' , 'realistic' )
38
- hits_at_5 = get_predicate_metric (predicate_metrics , 'hits_at_5' , predicate_label , 'ComplEx ' , 'tail' , 'realistic' )
39
- hits_at_10 = get_predicate_metric (predicate_metrics , 'hits_at_10' , predicate_label , 'ComplEx ' , 'tail' , 'realistic' )
35
+ arithmetic_mean_rank = get_predicate_metric (predicate_metrics , 'arithmetic_mean_rank' , predicate_label , 'complex ' , 'tail' , 'realistic' )
36
+ hits_at_1 = get_predicate_metric (predicate_metrics , 'hits_at_1' , predicate_label , 'complex ' , 'tail' , 'realistic' )
37
+ hits_at_3 = get_predicate_metric (predicate_metrics , 'hits_at_3' , predicate_label , 'complex ' , 'tail' , 'realistic' )
38
+ hits_at_5 = get_predicate_metric (predicate_metrics , 'hits_at_5' , predicate_label , 'complex ' , 'tail' , 'realistic' )
39
+ hits_at_10 = get_predicate_metric (predicate_metrics , 'hits_at_10' , predicate_label , 'complex ' , 'tail' , 'realistic' )
40
40
print (f'Arithmetic mean rank: { arithmetic_mean_rank } ' )
41
41
print (f'Hits at 1: { hits_at_1 } ' )
42
42
print (f'Hits at 3: { hits_at_3 } ' )
@@ -45,7 +45,7 @@ def main():
45
45
46
46
47
47
def load_wikidata5m_dataset (subset_type : SubsetType ):
48
- return pd .read_csv (f'dataset/knowledge_graph /wikidata5m_transductive_{ subset_type } .txt' , sep = '\t ' ,
48
+ return pd .read_csv (f'dataset/wikidata5m /wikidata5m_transductive_{ subset_type } .txt' , sep = '\t ' ,
49
49
names = ['S' , 'P' , 'O' ])
50
50
51
51
0 commit comments