diff --git a/src/skelcast/models/module.py b/src/skelcast/models/module.py index 622e22a..9c534b4 100644 --- a/src/skelcast/models/module.py +++ b/src/skelcast/models/module.py @@ -28,6 +28,12 @@ def validation_step(self, *args, **kwargs): Implements a validation step of a module """ pass + + def from_pretrained(model_path=None): + """ + Implements a method to load a pretrained model + """ + pass def gradient_flow(self): """ diff --git a/src/skelcast/models/transformers/sttf.py b/src/skelcast/models/transformers/sttf.py index 6bbf8e1..afc56e5 100644 --- a/src/skelcast/models/transformers/sttf.py +++ b/src/skelcast/models/transformers/sttf.py @@ -257,4 +257,12 @@ def predict(self, sample, n_steps=10, observe_from_to=[10]): to_ += 1 sample_input = sample[:, from_:to_, ...] - return torch.stack(forecasted, dim=1) \ No newline at end of file + return torch.stack(forecasted, dim=1) + + def from_pretrained(self, model_path=None): + if model_path is None: + raise ValueError('`model_path` must be provided.') + checkpoint = torch.load(model_path) + model_state_dict = checkpoint['model_state_dict'] + self.load_state_dict(model_state_dict) + return self \ No newline at end of file diff --git a/tools/visualize_skel_movement.py b/tools/visualize_skel_movement.py index 8c7709a..6242f26 100644 --- a/tools/visualize_skel_movement.py +++ b/tools/visualize_skel_movement.py @@ -27,9 +27,10 @@ max_number_of_bodies=1, transforms=tf) model = SpatioTemporalTransformer(n_joints=25, d_model=256, n_blocks=3, n_heads=8, d_head=16, mlp_dim=512, loss_fn=nn.SmoothL1Loss(), dropout=0.5) # TODO: Remove the hard coding of the checkpoint path - checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') - model_state_dict = checkpoint['model_state_dict'] - model.load_state_dict(model_state_dict) + # checkpoint = torch.load('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') + # model_state_dict = checkpoint['model_state_dict'] + # model.load_state_dict(model_state_dict) + model.from_pretrained('/home/kaseris/Documents/mount/checkpoints_forecasting/presto-class/checkpoint_epoch_16_2024-01-05_092620.pt') model = model.to('cpu') skeleton, label = dataset[args.sample] seq_len, n_bodies, n_joints, n_dims = skeleton.shape