Skip to content

Commit

Permalink
style(predict): update parse_opt()
Browse files Browse the repository at this point in the history
  • Loading branch information
zjykzj committed Oct 2, 2023
1 parent de34b0d commit 9ae19f9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions predict_rpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

def parse_opt():
parser = argparse.ArgumentParser(description='Predict RPNet')
parser.add_argument('image', metavar='IMAGE', type=str, default="./assets/1.jpg",
help='path to image')
parser.add_argument('rpnet', metavar='RPNet', type=str, default="./runs/RPNet-e60.pth",
help='path to pretrained path')
parser.add_argument('image', metavar='IMAGE', type=str, default="./assets/1.jpg",
help='path to image')

args = parser.parse_args()
print(f"args: {args}")
Expand All @@ -49,7 +49,9 @@ def parse_opt():
# rpnet_pretrained = "runs/RPNet-e60.pth"
rpnet_pretrained = args.rpnet
print(f"Loading RPNet pretrained: {rpnet_pretrained}")
model.load_state_dict(torch.load(rpnet_pretrained, map_location='cpu'))
ckpt = torch.load(rpnet_pretrained, map_location='cpu')
ckpt = {k.replace("module.", ""): v for k, v in ckpt.items()}
model.load_state_dict(ckpt, strict=True)
model = model.to(device)
model.eval()

Expand Down
8 changes: 5 additions & 3 deletions predict_wr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

def parse_opt():
parser = argparse.ArgumentParser(description='Predict wR2')
parser.add_argument('image', metavar='IMAGE', type=str, default="./assets/1.jpg",
help='path to image')
parser.add_argument('wr2', metavar='wR2', type=str, default="./runs/wR2-e45.pth",
help='path to pretrained path')
parser.add_argument('image', metavar='IMAGE', type=str, default="./assets/1.jpg",
help='path to image')

args = parser.parse_args()
print(f"args: {args}")
Expand All @@ -44,7 +44,9 @@ def parse_opt():
# wr2_pretrained = "runs/wR2-e45.pth"
wr2_pretrained = args.wr2
print(f"Loading wR2 pretrained: {wr2_pretrained}")
model.load_state_dict(torch.load(wr2_pretrained, map_location='cpu'))
ckpt = torch.load(wr2_pretrained, map_location='cpu')
ckpt = {k.replace("module.", ""): v for k, v in ckpt.items()}
model.load_state_dict(ckpt, strict=True)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down

0 comments on commit 9ae19f9

Please sign in to comment.