Skip to content

Commit b8c440a

Browse files
committed
Convert embeddings to complex64 tensors for model conversion
1 parent 3f4b30d commit b8c440a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pretrained_models/create_pykeen_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def main():
2424
train_factory = TriplesFactory.from_path_binary(triples_factory_path)
2525

2626
print(f'[X] Loading entity and relation embeddings from {entity_embeddings_path} and {relation_embeddings_path}')
27-
entity_embeddings = torch.from_numpy(np.load(entity_embeddings_path))
28-
relation_embeddings = torch.from_numpy(np.load(relation_embeddings_path))
27+
entity_embeddings = torch.from_numpy(np.load(entity_embeddings_path)).type(torch.complex64)
28+
relation_embeddings = torch.from_numpy(np.load(relation_embeddings_path)).type(torch.complex64)
2929

3030
print(f'[X] Creating PyKEEN model from trained {model_name} embeddings')
3131
model = ModelClass(

0 commit comments

Comments
 (0)