File tree 1 file changed +10
-5
lines changed
1 file changed +10
-5
lines changed Original file line number Diff line number Diff line change @@ -202,7 +202,12 @@ def main():
202
202
203
203
args = parser .parse_args ()
204
204
use_cuda = torch .cuda .is_available ()
205
- device = torch .device (args .deviceID if use_cuda else "cpu" )
205
+ if use_cuda :
206
+ device = torch .device (args .deviceID )
207
+ else : # attempt upgrade to Metal acceleration
208
+ use_mps = torch .backends .mps .is_built ()
209
+ device = torch .device ("mps" if use_mps else "cpu" )
210
+
206
211
resultModelFile = "pretrained_model"
207
212
208
213
@@ -213,11 +218,11 @@ def main():
213
218
model .cuda ()
214
219
if os .path .isfile (resultModelFile ):
215
220
try :
216
- model .load_state_dict (torch .load (resultModelFile ))
217
- except :
218
- print ( "Cannot load the saved model" )
221
+ model .load_state_dict (torch .load (resultModelFile , map_location = device ))
222
+ except RuntimeError as e :
223
+ raise RuntimeError ( f "Cannot load the saved model: \n { e } " )
219
224
220
225
demo (model , device , args .image )
221
226
222
227
if __name__ == "__main__" :
223
- main ()
228
+ main ()
You can’t perform that action at this time.
0 commit comments