Skip to content

Commit 6856677

Browse files
support model loading on m1
1 parent 93d1637 commit 6856677

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

Demo.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,12 @@ def main():
202202

203203
args = parser.parse_args()
204204
use_cuda = torch.cuda.is_available()
205-
device = torch.device(args.deviceID if use_cuda else "cpu")
205+
if use_cuda:
206+
device = torch.device(args.deviceID)
207+
else: # attempt upgrade to Metal acceleration
208+
use_mps = torch.backends.mps.is_built()
209+
device = torch.device("mps" if use_mps else "cpu")
210+
206211
resultModelFile = "pretrained_model"
207212

208213

@@ -213,11 +218,11 @@ def main():
213218
model.cuda()
214219
if os.path.isfile(resultModelFile):
215220
try:
216-
model.load_state_dict(torch.load(resultModelFile))
217-
except:
218-
print("Cannot load the saved model")
221+
model.load_state_dict(torch.load(resultModelFile, map_location=device))
222+
except RuntimeError as e:
223+
raise RuntimeError(f"Cannot load the saved model:\n{e}")
219224

220225
demo(model, device, args.image)
221226

222227
if __name__ == "__main__":
223-
main()
228+
main()

0 commit comments

Comments
 (0)