Skip to content

Commit

Permalink
Remove old implementation of Human3.6m from the dataset module. May b…
Browse files Browse the repository at this point in the history
…ecome too cluttered.
  • Loading branch information
kaseris committed Jan 23, 2024
1 parent 12a7bf4 commit 5b97eca
Showing 1 changed file with 2 additions and 57 deletions.
59 changes: 2 additions & 57 deletions src/skelcast/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,62 +299,6 @@ def store_to_cache(self, cache_file: str) -> None:
with open(cache_file, 'wb') as f:
pickle.dump(self.skeleton_files_clean, f)
logging.info(f"Stored {len(self.skeleton_files_clean)} files to cache file {cache_file}.")


@DATASETS.register_module()
class Human36mDataset(Dataset):
def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None:
self.data_path = data_path
self.use_hourglass_detections = use_hourglass_detections
self.train = train

self.train_inputs, self.test_inputs = [], []
self.act = []

if self.use_hourglass_detections:
train_2d_file = 'train_2d_ft.pth.tar'
test_2d_file = 'test_2d_ft.pth.tar'
else:
train_2d_file = 'train_2d.pth.tar'
train_2d_file = 'test_2d.pth.tar'

if self.train:
self.train_3d = torch.load(os.path.join(data_path, 'train_3d.pth.tar'))
self.train_2d = torch.load(os.path.join(data_path, train_2d_file))

for k2d in self.train_2d.keys():
(sub, act, fname) = k2d
k3d = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.train_3d[k3d].shape[0] == self.train_2d[k2d].shape[0], f'(training) 3d and 2d shapes not matching'
self.train_inputs.append(self.train_3d[k3d])
self.act.append(act)

else:
self.test_3d = torch.load(os.path.join(data_path, 'test_3d.pth.tar'))
self.test_2d = torch.load(os.path.join(data_path, test_2d_file))
for k2d in self.test_2d.keys():
(sub, act, fname) = k2d
k3d = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.test_2d[k2d].shape[0] == self.test_3d[k3d].shape[0], '(test) 3d and 2d shapes not matching'
self.test_inputs.append(self.test_3d[k3d])
self.act.append(act)

def __getitem__(self, index) -> Any:
if self.train:
# We want the sampeles to be returned as sequences
# i.e.: [seq_len, n_joints, 3]
x = torch.from_numpy(self.train_inputs[index]).float()
else:
x = torch.from_numpy(self.test_inputs[index]).float()
return x.view(-1, 16, 3), self.act[index]

def __len__(self):
if self.train:
return len(self.train_inputs)
else:
return len(self.test_inputs)


@DATASETS.register_module()
Expand Down Expand Up @@ -421,4 +365,5 @@ def hparam(self) -> dict:
return {
'DATA_history_length': self.history_length,
'DATA_prediction_horizon': self.prediction_horizon,
}
}

0 comments on commit 5b97eca

Please sign in to comment.