Skip to content

Commit

Permalink
Merge pull request #63 from kaseris/feature/forecast-sttf
Browse files Browse the repository at this point in the history
Add forecasting interface to the SpatioTemporalTransformer module
  • Loading branch information
kaseris committed Jan 5, 2024
2 parents 85c1a84 + 3c44c33 commit 07ad810
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions src/skelcast/models/transformers/sttf.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,38 @@ def validation_step(self, *args, **kwargs):
with torch.no_grad():
return self.training_step(*args, **kwargs)

def predict(self, *args, **kwargs):
pass
def predict(self, sample, n_steps=10, observe_from_to=[10]):
"""Predicts `n_steps` into the future given the sample.
Args:
- sample `torch.Tensor`: The sample to predict from. The shape must be (seq_len, n_skels, n_joints, 3)
- n_steps `int`: The number of steps to predict into the future
- observe_from_to `list`: The start and end of the observation. If the list contains only one element, then it is the end of the observation.
"""
if len(observe_from_to) == 1:
from_ = 0
to_ = observe_from_to[0]
else:
if len(observe_from_to) > 2:
raise ValueError('`observe_from_to` must be a list of length 1 or 2.')
if observe_from_to[0] > observe_from_to[1]:
raise ValueError('The start of observation must be before the end of observation.')
from_, to_ = observe_from_to

sample = sample.squeeze(1).unsqueeze(0)
sample_input = sample[:, from_:to_, ...]
forecasted = []
self.eval()
with torch.no_grad():
for _ in range(n_steps):
prediction = self(sample_input.to(torch.float32))
forecasted.append(prediction[:, -1:, ...].squeeze(1).detach())
# Roll the sample input and replace the last element with the last element of the prediction
# sample_input = torch.roll(sample_input, -1, dims=1)
# print(sample_input)
# sample_input[:, -1, ...] = prediction[:, -1:, ...].unsqueeze(1)
from_ += 1
to_ += 1
sample_input = sample[:, from_:to_, ...]

return torch.stack(forecasted, dim=1)

0 comments on commit 07ad810

Please sign in to comment.