Skip to content

Commit

Permalink
FIX: allows cpu loading
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisdg committed Jun 23, 2022
1 parent be5144d commit 5210046
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion unet3d/models/pytorch/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def build_or_load_model(model_name, model_filename, n_features, n_outputs, n_gpu
elif n_gpus > 0:
model = model.cuda()
if os.path.exists(model_filename):
state_dict = torch.load(model_filename)
if n_gpus > 0:
state_dict = torch.load(model_filename)
else:
state_dict = torch.load(model_filename, map_location=torch.device('cpu'))
model = load_state_dict(model, state_dict, n_gpus=n_gpus, strict=strict)
return model

Expand Down

0 comments on commit 5210046

Please sign in to comment.