Skip to content

Commit

Permalink
Merge pull request #67 from kaseris/datasets/human36m
Browse files Browse the repository at this point in the history
Human 3.6m Dataset
  • Loading branch information
kaseris committed Jan 15, 2024
2 parents 8669520 + cf2807f commit 232394e
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/skelcast/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,59 @@ def store_to_cache(self, cache_file: str) -> None:
with open(cache_file, 'wb') as f:
pickle.dump(self.skeleton_files_clean, f)
logging.info(f"Stored {len(self.skeleton_files_clean)} files to cache file {cache_file}.")


@DATASETS.register_module()
class Human36mDataset(Dataset):
def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None:
self.data_path = data_path
self.use_hourglass_detections = use_hourglass_detections
self.train = train

self.train_inputs, self.test_inputs = [], []
self.act = []

if self.use_hourglass_detections:
train_2d_file = 'train_2d_ft.pth.tar'
test_2d_file = 'test_2d_ft.pth.tar'
else:
train_2d_file = 'train_2d.pth.tar'
train_2d_file = 'test_2d.pth.tar'

if self.train:
self.train_3d = torch.load(os.path.join(data_path, 'train_3d.pth.tar'))
self.train_2d = torch.load(os.path.join(data_path, train_2d_file))

for k2d in self.train_2d.keys():
(sub, act, fname) = k2d
k3d = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.train_3d[k3d].shape[0] == self.train_2d[k2d].shape[0], f'(training) 3d and 2d shapes not matching'
self.train_inputs.append(self.train_3d[k3d])
self.act.append(act)

else:
self.test_3d = torch.load(os.path.join(data_path, 'test_3d.pth.tar'))
self.test_2d = torch.load(os.path.join(data_path, test_2d_file))
for k2d in self.test_2d.keys():
(sub, act, fname) = k2d
k3d = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.test_2d[k2d].shape[0] == self.test_3d[k3d].shape[0], '(test) 3d and 2d shapes not matching'
self.test_inputs.append(self.test_3d[k3d])
self.act.append(act)

def __getitem__(self, index) -> Any:
if self.train:
# We want the sampeles to be returned as sequences
# i.e.: [seq_len, n_joints, 3]
x = torch.from_numpy(self.train_inputs[index]).float()
else:
x = torch.from_numpy(self.test_inputs[index]).float()
return x.view(-1, 16, 3), self.act[index]

def __len__(self):
if self.train:
return len(self.train_inputs)
else:
return len(self.test_inputs)

0 comments on commit 232394e

Please sign in to comment.