
Commit 4de7cbf

Author: Morten Terhart
Message: Add project files to repository
Parent: 6dcc59a

21 files changed: +191604 -0 lines

.gitignore (+177 lines)
@@ -0,0 +1,177 @@
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


### Project files

# Wikidata5M dataset
dataset/aliases
dataset/documents
dataset/knowledge_graph

# Trained embeddings and mappings
embeddings/*/trained_model.pkl
embeddings/*/training_triples/
pretrained_embeddings/

# JetBrains files
.idea/
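Note: the dataset and embedding paths ignored under "Project files" can be spot-checked with git check-ignore -v <path>; for example, git check-ignore -v dataset/aliases prints the .gitignore rule and line that matches.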

compute_predicate_metrics.py (+117 lines)
@@ -0,0 +1,117 @@
import os
from timeit import default_timer as timer
from datetime import timedelta

import numpy as np
import pandas as pd
import torch
from pykeen.evaluation import RankBasedEvaluator, RankBasedMetricResults
from pykeen.triples import TriplesFactory
from pykeen.datasets import get_dataset
from pykeen.evaluation.rank_based_evaluator import _iter_ranks


def main():
    wikidata5m_test_set = pd.read_csv('dataset/knowledge_graph/wikidata5m_transductive_test.txt', sep="\t",
                                      names=['S', 'P', 'O'], header=None)
    trained_models = get_trained_models()
    print(
        f'[X] Loaded {len(trained_models)} trained models and {get_number_of_predicates(wikidata5m_test_set)} test '
        f'splits per predicate')

    wikidata5m_dataset = get_dataset(dataset='Wikidata5M')

    print(
        f'[X] Loaded {wikidata5m_dataset.get_normalized_name()} dataset with {wikidata5m_dataset.training.num_triples} '
        f'training, {wikidata5m_dataset.validation.num_triples} validation and {wikidata5m_dataset.testing.num_triples} '
        'test triples')

    print('[X] Starting evaluation on models')
    start = timer()
    predicate_metrics = evaluate_models_per_predicate(trained_models, wikidata5m_test_set, wikidata5m_dataset)

    print(f'[X] Finished evaluation in {timedelta(seconds=timer() - start)}')

    predicate_metrics.to_csv('metrics/predicate_metrics.csv')


def get_number_of_predicates(dataset_df):
    return dataset_df['P'].nunique()


def get_test_set_per_predicate(test_set_file):
    test_set = pd.read_csv(test_set_file, sep="\t", names=['S', 'P', 'O'], header=None)
    return test_set.groupby('P')


def get_trained_models():
    return {
        'ComplEx': _load_trained_model('embeddings/ComplEx'),
        'DistMult': _load_trained_model('embeddings/DistMult'),
        'SimplE': _load_trained_model('embeddings/SimplE'),
        'TransE': _load_trained_model('embeddings/TransE')
    }


def _load_trained_model(saved_model_dir):
    return {
        'model': torch.load(os.path.join(saved_model_dir, 'trained_model.pkl')),
        'factory': TriplesFactory.from_path_binary(os.path.join(saved_model_dir, 'training_triples'))
    }


def evaluate_models_per_predicate(trained_models, wikidata5m_test_set, dataset):
    aggregated_metrics = pd.DataFrame()
    for model_name, result in trained_models.items():
        model = result['model']
        training_factory = result['factory']

        # Encode the raw test triples with the entity/relation ID mappings used during training
        test_factory = TriplesFactory.from_labeled_triples(
            triples=wikidata5m_test_set.values,
            entity_to_id=training_factory.entity_to_id,
            relation_to_id=training_factory.relation_to_id
        )

        # Filtered rank-based evaluation; keep the raw ranks for per-predicate regrouping
        evaluator = RankBasedEvaluator(clear_on_finalize=False)
        evaluator.evaluate(
            model=model,
            mapped_triples=test_factory.mapped_triples,
            additional_filter_triples=[
                dataset.training.mapped_triples,
                dataset.validation.mapped_triples
            ]
        )

        # Attach the per-triple ranks and candidate counts to the labeled test triples
        ranks_df = test_factory.tensor_to_df(
            tensor=test_factory.mapped_triples,
            **{"-".join(("rank",) + key): np.concatenate(value) for key, value in evaluator.ranks.items()},
            **{"-".join(("num_candidates", key)): np.concatenate(value) for key, value in
               evaluator.num_candidates.items()},
        )

        # Recompute the rank-based metrics separately for each predicate
        for (relation_id, relation_label), group in ranks_df.groupby(by=['relation_id', 'relation_label']):
            relation_ranks = {}
            relation_num_candidates = {}

            for column in group.columns:
                if column.startswith('rank-'):
                    relation_ranks[tuple(column.split('-')[1:])] = [group[column].values]
                elif column.startswith('num_candidates-'):
                    relation_num_candidates[column.split('-')[1]] = [group[column].values]

            metric_results = RankBasedMetricResults.from_ranks(
                metrics=evaluator.metrics,
                rank_and_candidates=_iter_ranks(ranks=relation_ranks, num_candidates=relation_num_candidates)
            ).to_df()

            metric_results['relation_id'] = relation_id
            metric_results['relation_label'] = relation_label
            metric_results['model'] = model_name

            aggregated_metrics = pd.concat([aggregated_metrics, metric_results], ignore_index=True)

    return aggregated_metrics


if __name__ == '__main__':
    main()
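Note: the evaluation above hinges on encoding the test triples with the ID mappings saved alongside each trained model, so that mapped_triples index the same embeddings the model was trained with. A minimal sketch of that step, using hypothetical toy labels (Q1, P31, Q5 stand in for real Wikidata5M identifiers):

import numpy as np
from pykeen.triples import TriplesFactory

# Toy mappings standing in for the ones stored with a trained model
entity_to_id = {'Q1': 0, 'Q5': 1}
relation_to_id = {'P31': 0}

# One labeled (subject, predicate, object) triple, as read from the test TSV
triples = np.array([['Q1', 'P31', 'Q5']], dtype=str)

factory = TriplesFactory.from_labeled_triples(
    triples=triples,
    entity_to_id=entity_to_id,
    relation_to_id=relation_to_id,
)
print(factory.mapped_triples)  # tensor([[0, 0, 1]]), using the supplied IDs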

dataset/convert_csv_to_turtle.py (+24 lines)
@@ -0,0 +1,24 @@
import pandas as pd

wikidata_prefix = 'https://www.wikidata.org/wiki/'


def main():
    df = pd.read_csv('./knowledge_graph/wikidata5m_transductive_train.tsv', sep='\t', names=['S', 'P', 'O'])

    # Transform triples to Turtle format in the dataframe
    turtle_df = df.apply(row_to_turtle, axis=1)

    turtle_file = './knowledge_graph/wikidata5m_transductive_train.ttl'
    with open(turtle_file, 'w') as f:
        # A Turtle @prefix directive must be terminated by a dot
        f.write(f'@prefix wd: <{wikidata_prefix}> .\n\n')

    # Append the formatted triples below the prefix declaration
    turtle_df.to_csv(turtle_file, mode='a', header=False, index=False)


def row_to_turtle(row):
    return f'wd:{row["S"]} wd:{row["P"]} wd:{row["O"]} .'


if __name__ == '__main__':
    main()
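For a hypothetical row (S='Q30', P='P36', O='Q61'), row_to_turtle yields:

row_to_turtle(pd.Series({'S': 'Q30', 'P': 'P36', 'O': 'Q61'}))
# -> 'wd:Q30 wd:P36 wd:Q61 .'

Together with the @prefix declaration, the wd: prefix abbreviates the full https://www.wikidata.org/wiki/ IRIs in the output file.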

embeddings/ComplEx/metadata.json (+1 line)
@@ -0,0 +1 @@
{}
