Merge pull request #72 from kaseris/datasets/human36m
Human 3.6M Whole Body
kaseris authored Jan 23, 2024
2 parents 38c7e96 + 5b97eca commit e0810bb
Showing 9 changed files with 1,104 additions and 58 deletions.
101 changes: 101 additions & 0 deletions notebooks/human36m.ipynb
@@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DATA_DIR = '/media/kaseris/FastData/wetransfer_human36_2024-01-15_1137/human36'\n",
"DATASET = 'data_3d_h36m.npz'\n",
"\n",
"import os\n",
"import os.path as osp\n",
"\n",
"import numpy as np\n",
"\n",
"fname = osp.join(DATA_DIR, DATASET)\n",
"data = np.load(fname, allow_pickle=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"subjects = list(data['positions_3d'].item().keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"total_frames = 0\n",
"for k, v in data['positions_3d'].item().items():\n",
" print(k)\n",
" for action, v in v.items():\n",
" print(action)\n",
" print(v.shape)\n",
" total_frames += v.shape[0]\n",
"print(f'Total frames: {total_frames:,}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v = data['positions_3d'].item()['S1']['WalkDog 1']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"\n",
"ax.scatter(v[0, :, 0], v[0, :, 1], v[0, :, 2], cmap='viridis')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "scraping",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
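
The notebook establishes the on-disk layout: data_3d_h36m.npz holds a positions_3d object keyed by subject (e.g. S1), each mapping action names (e.g. WalkDog 1) to arrays of shape (frames, joints, 3). As an illustrative bridge to the forecasting setting this repository targets (see the DATA_history_length / DATA_prediction_horizon hyperparameters in dataset.py below), here is a minimal sketch of slicing one such sequence into history/prediction windows; the path and window lengths are placeholders, not values from this PR.

import os.path as osp
import numpy as np

DATA_DIR = '/path/to/human36'  # placeholder; the notebook uses a local absolute path
data = np.load(osp.join(DATA_DIR, 'data_3d_h36m.npz'), allow_pickle=True)
seq = data['positions_3d'].item()['S1']['WalkDog 1']  # (frames, joints, 3)

history_length, prediction_horizon = 25, 10  # illustrative values
window = history_length + prediction_horizon
for t in range(0, seq.shape[0] - window, window):
    x = seq[t:t + history_length]            # observed frames
    y = seq[t + history_length:t + window]   # frames to predict
    assert x.shape[0] == history_length and y.shape[0] == prediction_horizon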
3 changes: 2 additions & 1 deletion src/skelcast/data/__init__.py
@@ -5,4 +5,5 @@
TRANSFORMS = Registry()

from .dataset import NTURGBDCollateFn, NTURGBDDataset
from .transforms import MinMaxScaleTransform
from .human36mwb import Human36MWBDataset
59 changes: 2 additions & 57 deletions src/skelcast/data/dataset.py
@@ -299,62 +299,6 @@ def store_to_cache(self, cache_file: str) -> None:
        with open(cache_file, 'wb') as f:
            pickle.dump(self.skeleton_files_clean, f)
        logging.info(f"Stored {len(self.skeleton_files_clean)} files to cache file {cache_file}.")


@DATASETS.register_module()
class Human36mDataset(Dataset):
    def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None:
        self.data_path = data_path
        self.use_hourglass_detections = use_hourglass_detections
        self.train = train

        self.train_inputs, self.test_inputs = [], []
        self.act = []

        if self.use_hourglass_detections:
            train_2d_file = 'train_2d_ft.pth.tar'
            test_2d_file = 'test_2d_ft.pth.tar'
        else:
            train_2d_file = 'train_2d.pth.tar'
            test_2d_file = 'test_2d.pth.tar'

        if self.train:
            self.train_3d = torch.load(os.path.join(data_path, 'train_3d.pth.tar'))
            self.train_2d = torch.load(os.path.join(data_path, train_2d_file))

            for k2d in self.train_2d.keys():
                (sub, act, fname) = k2d
                k3d = k2d
                k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
                assert self.train_3d[k3d].shape[0] == self.train_2d[k2d].shape[0], '(training) 3d and 2d shapes not matching'
                self.train_inputs.append(self.train_3d[k3d])
                self.act.append(act)

        else:
            self.test_3d = torch.load(os.path.join(data_path, 'test_3d.pth.tar'))
            self.test_2d = torch.load(os.path.join(data_path, test_2d_file))
            for k2d in self.test_2d.keys():
                (sub, act, fname) = k2d
                k3d = k2d
                k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
                assert self.test_2d[k2d].shape[0] == self.test_3d[k3d].shape[0], '(test) 3d and 2d shapes not matching'
                self.test_inputs.append(self.test_3d[k3d])
                self.act.append(act)

    def __getitem__(self, index) -> Any:
        if self.train:
            # We want the samples to be returned as sequences,
            # i.e.: [seq_len, n_joints, 3]
            x = torch.from_numpy(self.train_inputs[index]).float()
        else:
            x = torch.from_numpy(self.test_inputs[index]).float()
        return x.view(-1, 16, 3), self.act[index]

    def __len__(self):
        if self.train:
            return len(self.train_inputs)
        else:
            return len(self.test_inputs)


@DATASETS.register_module()
@@ -421,4 +365,5 @@ def hparam(self) -> dict:
        return {
            'DATA_history_length': self.history_length,
            'DATA_prediction_horizon': self.prediction_horizon,
        }

1 change: 1 addition & 0 deletions src/skelcast/data/human36mwb/__init__.py
@@ -0,0 +1 @@
from .human36mwb import Human36MWBDataset
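
The re-export makes the new dataset importable from either skelcast.data or skelcast.data.human36mwb. Its constructor signature is not part of the rendered diff, so the usage sketch below is hypothetical; the data_path argument is an assumption, not confirmed by this PR.

from skelcast.data import Human36MWBDataset

# Hypothetical usage; the constructor arguments are assumed, not shown in this diff.
dataset = Human36MWBDataset(data_path='/path/to/human36m_wb')
print(len(dataset))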
122 changes: 122 additions & 0 deletions src/skelcast/data/human36mwb/camera.py
@@ -0,0 +1,122 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch

from skelcast.data.human36mwb.quaternion import qrot, qinverse


def wrap(func, *args, unsqueeze=False):
    """
    Wrap a torch function so it can be called with NumPy arrays.
    Input and return types are seamlessly converted.
    """

    # Convert input types where applicable
    args = list(args)
    for i, arg in enumerate(args):
        if type(arg) == np.ndarray:
            args[i] = torch.from_numpy(arg)
            if unsqueeze:
                args[i] = args[i].unsqueeze(0)

    result = func(*args)

    # Convert output types where applicable
    if isinstance(result, tuple):
        result = list(result)
        for i, res in enumerate(result):
            if type(res) == torch.Tensor:
                if unsqueeze:
                    res = res.squeeze(0)
                result[i] = res.numpy()
        return tuple(result)
    elif type(result) == torch.Tensor:
        if unsqueeze:
            result = result.squeeze(0)
        return result.numpy()
    else:
        return result


def normalize_screen_coordinates(X, w, h):
    assert X.shape[-1] == 2

    # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
    return X / w * 2 - [1, h / w]


def image_coordinates(X, w, h):
    assert X.shape[-1] == 2

    # Reverse camera frame normalization
    return (X + [1, h / w]) * w / 2


def world_to_camera(X, R, t):
    Rt = wrap(qinverse, R)  # Invert rotation
    return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t)  # Rotate and translate


def camera_to_world(X, R, t):
    return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t


def project_to_2d(X, camera_params):
    """
    Project 3D points to 2D using the Human3.6M camera projection function.
    This is a differentiable and batched reimplementation of the original MATLAB script.
    Arguments:
    X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
    """
    assert X.shape[-1] == 3
    assert len(camera_params.shape) == 2
    assert camera_params.shape[-1] == 9
    assert X.shape[0] == camera_params.shape[0]

    while len(camera_params.shape) < len(X.shape):
        camera_params = camera_params.unsqueeze(1)

    # Focal length, principal point, radial and tangential distortion coefficients
    f = camera_params[..., :2]
    c = camera_params[..., 2:4]
    k = camera_params[..., 4:7]
    p = camera_params[..., 7:]

    XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
    r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)

    radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
    tan = torch.sum(p * XX, dim=len(XX.shape)-1, keepdim=True)

    XXX = XX * (radial + tan) + p * r2

    return f * XXX + c


def project_to_2d_linear(X, camera_params):
    """
    Project 3D points to 2D using only linear parameters (focal length and principal point).
    Arguments:
    X -- 3D points in *camera space* to transform (N, *, 3)
    camera_params -- intrinsic parameters (N, 2+2+3+2=9)
    """
    assert X.shape[-1] == 3
    assert len(camera_params.shape) == 2
    assert camera_params.shape[-1] == 9
    assert X.shape[0] == camera_params.shape[0]

    while len(camera_params.shape) < len(X.shape):
        camera_params = camera_params.unsqueeze(1)

    f = camera_params[..., :2]
    c = camera_params[..., 2:4]

    XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)

    return f * XX + c
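
As a quick orientation for the functions above, a small illustrative smoke test follows (not part of the PR). The camera numbers are made up, and it assumes qrot follows the usual convention of rotating vectors by quaternions.

import numpy as np
import torch

from skelcast.data.human36mwb.camera import (
    wrap, normalize_screen_coordinates, image_coordinates,
    project_to_2d, project_to_2d_linear,
)
from skelcast.data.human36mwb.quaternion import qrot

# normalize_screen_coordinates and image_coordinates are exact inverses.
pts = np.array([[100.0, 200.0], [640.0, 360.0]])
norm = normalize_screen_coordinates(pts, w=1280, h=720)
assert np.allclose(image_coordinates(norm, w=1280, h=720), pts)

# wrap() lets the torch-based quaternion ops accept NumPy inputs;
# unsqueeze=True adds and removes a batch dimension around the call.
q = np.array([1.0, 0.0, 0.0, 0.0])  # identity quaternion
v = np.array([1.0, 2.0, 3.0])
assert np.allclose(wrap(qrot, q, v, unsqueeze=True), v)

# With all distortion coefficients zero, the full projection reduces to the
# linear one: camera_params packs (fx, fy, cx, cy, k1, k2, k3, p1, p2).
X = torch.rand(2, 5, 3) + 1.0  # camera-space points with positive depth
cam = torch.tensor([[1.145, 1.144, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]] * 2)
assert torch.allclose(project_to_2d(X, cam), project_to_2d_linear(X, cam))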