We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3f4b30d commit b8c440aCopy full SHA for b8c440a
pretrained_models/create_pykeen_model.py
@@ -24,8 +24,8 @@ def main():
24
train_factory = TriplesFactory.from_path_binary(triples_factory_path)
25
26
print(f'[X] Loading entity and relation embeddings from {entity_embeddings_path} and {relation_embeddings_path}')
27
- entity_embeddings = torch.from_numpy(np.load(entity_embeddings_path))
28
- relation_embeddings = torch.from_numpy(np.load(relation_embeddings_path))
+ entity_embeddings = torch.from_numpy(np.load(entity_embeddings_path)).type(torch.complex64)
+ relation_embeddings = torch.from_numpy(np.load(relation_embeddings_path)).type(torch.complex64)
29
30
print(f'[X] Creating PyKEEN model from trained {model_name} embeddings')
31
model = ModelClass(
0 commit comments