Skip to content

Commit 19f6270

Browse files
authored
Merge pull request #75 from instabaines/Illegal-memory-access-patch
Patch for illegal memory access error when using cuda:device!=0
2 parents e443a09 + 676827e commit 19f6270

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

easy_tpp/torch_wrapper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@ def __init__(self, model, base_config, model_config, trainer_config):
2020
self.base_config = base_config
2121
self.model_config = model_config
2222
self.trainer_config = trainer_config
23-
23+
2424
self.model_id = self.base_config.model_id
25+
# Sometimes PyTorch may not switch the active device context for all operations
26+
# This causes illegal memory access error
27+
if self.trainer_config.gpu!=-1:
28+
torch.cuda.set_device(self.trainer_config.gpu)
2529
self.device = set_device(self.trainer_config.gpu)
2630

2731
self.model.to(self.device)

0 commit comments

Comments
 (0)