Skip to content

Commit

Permalink
Merge pull request #64 from kaseris/feature/from-pretrained
Browse files Browse the repository at this point in the history
from_pretrained interface for the skelcast modules
  • Loading branch information
kaseris committed Jan 9, 2024
2 parents 07ad810 + 12a1552 commit 6723f1e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/skelcast/models/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def validation_step(self, *args, **kwargs):
Implements a validation step of a module
"""
pass

def from_pretrained(model_path=None):
"""
Implements a method to load a pretrained model
"""
pass

def gradient_flow(self):
"""
Expand Down
10 changes: 9 additions & 1 deletion src/skelcast/models/transformers/sttf.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,12 @@ def predict(self, sample, n_steps=10, observe_from_to=[10]):
to_ += 1
sample_input = sample[:, from_:to_, ...]

return torch.stack(forecasted, dim=1)
return torch.stack(forecasted, dim=1)

def from_pretrained(self, model_path=None):
if model_path is None:
raise ValueError('`model_path` must be provided.')
checkpoint = torch.load(model_path)
model_state_dict = checkpoint['model_state_dict']
self.load_state_dict(model_state_dict)
return self
7 changes: 4 additions & 3 deletions tools/visualize_skel_movement.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
max_number_of_bodies=1, transforms=tf)
model = SpatioTemporalTransformer(n_joints=25, d_model=256, n_blocks=3, n_heads=8, d_head=16, mlp_dim=512, loss_fn=nn.SmoothL1Loss(), dropout=0.5)
# TODO: Remove the hard coding of the checkpoint path
checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict)
# checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')
# model_state_dict = checkpoint['model_state_dict']
# model.load_state_dict(model_state_dict)
model.from_pretrained('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt')
model = model.to('cpu')
skeleton, label = dataset[args.sample]
seq_len, n_bodies, n_joints, n_dims = skeleton.shape
Expand Down

0 comments on commit 6723f1e

Please sign in to comment.