Skip to content

Commit

Permalink
Return the sequence action
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris committed Jan 15, 2024
1 parent c92b546 commit cf2807f
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/skelcast/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None
self.train = train

self.train_inputs, self.test_inputs = [], []
self.act = []

if self.use_hourglass_detections:
train_2d_file = 'train_2d_ft.pth.tar'
Expand All @@ -327,6 +328,7 @@ def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.train_3d[k3d].shape[0] == self.train_2d[k2d].shape[0], f'(training) 3d and 2d shapes not matching'
self.train_inputs.append(self.train_3d[k3d])
self.act.append(act)

else:
self.test_3d = torch.load(os.path.join(data_path, 'test_3d.pth.tar'))
Expand All @@ -337,6 +339,7 @@ def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.test_2d[k2d].shape[0] == self.test_3d[k3d].shape[0], '(test) 3d and 2d shapes not matching'
self.test_inputs.append(self.test_3d[k3d])
self.act.append(act)

def __getitem__(self, index) -> Any:
if self.train:
Expand All @@ -345,7 +348,7 @@ def __getitem__(self, index) -> Any:
x = torch.from_numpy(self.train_inputs[index]).float()
else:
x = torch.from_numpy(self.test_inputs[index]).float()
return x.view(-1, 16, 3)
return x.view(-1, 16, 3), self.act[index]

def __len__(self):
if self.train:
Expand Down

0 comments on commit cf2807f

Please sign in to comment.