Skip to content

Commit

Permalink
Merge pull request #65 from kaseris/fix/visualization-args
Browse files Browse the repository at this point in the history
Remove the hard coding of the model checkpoint
  • Loading branch information
kaseris committed Jan 9, 2024
2 parents 6723f1e + bb812d5 commit d9a5a54
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tools/visualize_skel_movement.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
argparser.add_argument('--dataset', type=str, required=True, help='Path to the dataset.')
argparser.add_argument('--sample', type=int, required=True, help='Sample index to visualize.')
argparser.add_argument('--cache-file', type=str, required=False, help='Path to the cache file.')
argparser.add_argument('--checkpoint', type=str, required=False, help='Path to the checkpoint file.',
default='/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')

args = argparser.parse_args()

Expand All @@ -30,7 +32,7 @@
# checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')
# model_state_dict = checkpoint['model_state_dict']
# model.load_state_dict(model_state_dict)
model.from_pretrained('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')
model.from_pretrained(args.checkpoint)
model = model.to('cpu')
skeleton, label = dataset[args.sample]
seq_len, n_bodies, n_joints, n_dims = skeleton.shape
Expand Down

0 comments on commit d9a5a54

Please sign in to comment.