Human 3.6M Whole Body #72

Merged
merged 7 commits into from
Jan 23, 2024
Changes from all commits
101 changes: 101 additions & 0 deletions notebooks/human36m.ipynb
@@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DATA_DIR = '/media/kaseris/FastData/wetransfer_human36_2024-01-15_1137/human36'\n",
"DATASET = 'data_3d_h36m.npz'\n",
"\n",
"import os\n",
"import os.path as osp\n",
"\n",
"import numpy as np\n",
"\n",
"fname = osp.join(DATA_DIR, DATASET)\n",
"data = np.load(fname, allow_pickle=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"subjects = list(data['positions_3d'].item().keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"total_frames = 0\n",
"for k, v in data['positions_3d'].item().items():\n",
" print(k)\n",
" for action, v in v.items():\n",
" print(action)\n",
" print(v.shape)\n",
" total_frames += v.shape[0]\n",
"print(f'Total frames: {total_frames:,}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v = data['positions_3d'].item()['S1']['WalkDog 1']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.add_subplot(111, projection='3d')\n",
"\n",
"ax.scatter(v[0, :, 0], v[0, :, 1], v[0, :, 2], cmap='viridis')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "scraping",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
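
For reference, the notebook's exploration boils down to this standalone sketch (the path is a placeholder; per the cells above, the npz holds a pickled dict under the 'positions_3d' key, keyed by subject, then action, with (frames, joints, 3) arrays):

import os.path as osp
import numpy as np

DATA_DIR = '/path/to/human36'  # placeholder; point at your local copy
data = np.load(osp.join(DATA_DIR, 'data_3d_h36m.npz'), allow_pickle=True)
positions = data['positions_3d'].item()  # subject -> action -> (frames, joints, 3)

total_frames = sum(arr.shape[0]
                   for actions in positions.values()
                   for arr in actions.values())
print(f'{len(positions)} subjects, {total_frames:,} frames')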
3 changes: 2 additions & 1 deletion src/skelcast/data/__init__.py
@@ -5,4 +5,5 @@
TRANSFORMS = Registry()

from .dataset import NTURGBDCollateFn, NTURGBDDataset
from .transforms import MinMaxScaleTransform
from .human36mwb import Human36MWBDataset
59 changes: 2 additions & 57 deletions src/skelcast/data/dataset.py
@@ -299,62 +299,6 @@ def store_to_cache(self, cache_file: str) -> None:
with open(cache_file, 'wb') as f:
pickle.dump(self.skeleton_files_clean, f)
logging.info(f"Stored {len(self.skeleton_files_clean)} files to cache file {cache_file}.")


@DATASETS.register_module()
class Human36mDataset(Dataset):
def __init__(self, data_path, use_hourglass_detections=True, train=True) -> None:
self.data_path = data_path
self.use_hourglass_detections = use_hourglass_detections
self.train = train

self.train_inputs, self.test_inputs = [], []
self.act = []

if self.use_hourglass_detections:
train_2d_file = 'train_2d_ft.pth.tar'
test_2d_file = 'test_2d_ft.pth.tar'
else:
train_2d_file = 'train_2d.pth.tar'
test_2d_file = 'test_2d.pth.tar'

if self.train:
self.train_3d = torch.load(os.path.join(data_path, 'train_3d.pth.tar'))
self.train_2d = torch.load(os.path.join(data_path, train_2d_file))

for k2d in self.train_2d.keys():
(sub, act, fname) = k2d
k3d = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.train_3d[k3d].shape[0] == self.train_2d[k2d].shape[0], f'(training) 3d and 2d shapes not matching'
self.train_inputs.append(self.train_3d[k3d])
self.act.append(act)

else:
self.test_3d = torch.load(os.path.join(data_path, 'test_3d.pth.tar'))
self.test_2d = torch.load(os.path.join(data_path, test_2d_file))
for k2d in self.test_2d.keys():
(sub, act, fname) = k2d
k3d = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k3d
assert self.test_2d[k2d].shape[0] == self.test_3d[k3d].shape[0], '(test) 3d and 2d shapes not matching'
self.test_inputs.append(self.test_3d[k3d])
self.act.append(act)

def __getitem__(self, index) -> Any:
if self.train:
# We want the samples to be returned as sequences
# i.e.: [seq_len, n_joints, 3]
x = torch.from_numpy(self.train_inputs[index]).float()
else:
x = torch.from_numpy(self.test_inputs[index]).float()
return x.view(-1, 16, 3), self.act[index]

def __len__(self):
if self.train:
return len(self.train_inputs)
else:
return len(self.test_inputs)


@DATASETS.register_module()
@@ -421,4 +365,5 @@ def hparam(self) -> dict:
return {
'DATA_history_length': self.history_length,
'DATA_prediction_horizon': self.prediction_horizon,
}
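
The key matching in the removed Human36mDataset deserves a note: 2D keys for Stacked Hourglass detections carry an '-sh' suffix on the file name, while the 3D ground-truth keys do not, so the suffix is stripped before the 3D lookup. A minimal sketch with a hypothetical key:

k2d = ('S1', 'Walking', 'Walking.54138969-sh')  # hypothetical 2D key
sub, act, fname = k2d
k3d = (sub, act, fname[:-3]) if fname.endswith('-sh') else k2d
assert k3d == ('S1', 'Walking', 'Walking.54138969')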

1 change: 1 addition & 0 deletions src/skelcast/data/human36mwb/__init__.py
@@ -0,0 +1 @@
from .human36mwb import Human36MWBDataset
122 changes: 122 additions & 0 deletions src/skelcast/data/human36mwb/camera.py
@@ -0,0 +1,122 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
import torch

from skelcast.data.human36mwb.quaternion import qrot, qinverse

def wrap(func, *args, unsqueeze=False):
"""
Wrap a torch function so it can be called with NumPy arrays.
Input and return types are seamlessly converted.
"""

# Convert input types where applicable
args = list(args)
for i, arg in enumerate(args):
if type(arg) == np.ndarray:
args[i] = torch.from_numpy(arg)
if unsqueeze:
args[i] = args[i].unsqueeze(0)

result = func(*args)

# Convert output types where applicable
if isinstance(result, tuple):
result = list(result)
for i, res in enumerate(result):
if type(res) == torch.Tensor:
if unsqueeze:
res = res.squeeze(0)
result[i] = res.numpy()
return tuple(result)
elif type(result) == torch.Tensor:
if unsqueeze:
result = result.squeeze(0)
return result.numpy()
else:
return result
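
# Usage sketch (not part of the original file): wrap lets the torch-based
# quaternion ops below run on NumPy inputs. For a single, unbatched
# quaternion/vector pair, unsqueeze=True adds and strips the batch dim.
# Assuming the (w, x, y, z) convention of the qrot this file imports:
#
#   q = np.array([0.7071068, 0.0, 0.7071068, 0.0])  # 90-degree turn about +y
#   v = np.array([1.0, 0.0, 0.0])
#   wrap(qrot, q, v, unsqueeze=True)  # -> array([ 0.,  0., -1.])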


def normalize_screen_coordinates(X, w, h):
assert X.shape[-1] == 2

# Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
return X/w*2 - [1, h/w]


def image_coordinates(X, w, h):
assert X.shape[-1] == 2

# Reverse camera frame normalization
return (X + [1, h/w])*w/2
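
# Round-trip sketch (illustrative, hypothetical 1000x1002 image):
# normalize_screen_coordinates maps [0, w] to [-1, 1] on x while preserving
# the aspect ratio on y, and image_coordinates inverts it exactly:
#
#   pts = np.array([[500.0, 501.0]])
#   norm = normalize_screen_coordinates(pts, w=1000, h=1002)  # -> [[0., 0.]]
#   np.allclose(image_coordinates(norm, w=1000, h=1002), pts)  # True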


def world_to_camera(X, R, t):
Rt = wrap(qinverse, R) # Invert rotation
return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) # Rotate and translate


def camera_to_world(X, R, t):
return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t


def project_to_2d(X, camera_params):
"""
Project 3D points to 2D using the Human3.6M camera projection function.
This is a differentiable and batched reimplementation of the original MATLAB script.

Arguments:
X -- 3D points in *camera space* to transform (N, *, 3)
camera_params -- intrinsic parameters (N, 2+2+3+2=9)
"""
assert X.shape[-1] == 3
assert len(camera_params.shape) == 2
assert camera_params.shape[-1] == 9
assert X.shape[0] == camera_params.shape[0]

while len(camera_params.shape) < len(X.shape):
camera_params = camera_params.unsqueeze(1)

f = camera_params[..., :2]
c = camera_params[..., 2:4]
k = camera_params[..., 4:7]
p = camera_params[..., 7:]

XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)
r2 = torch.sum(XX[..., :2]**2, dim=len(XX.shape)-1, keepdim=True)

radial = 1 + torch.sum(k * torch.cat((r2, r2**2, r2**3), dim=len(r2.shape)-1), dim=len(r2.shape)-1, keepdim=True)
tan = torch.sum(p*XX, dim=len(XX.shape)-1, keepdim=True)

XXX = XX*(radial + tan) + p*r2

return f*XXX + c

def project_to_2d_linear(X, camera_params):
"""
Project 3D points to 2D using only linear parameters (focal length and principal point).

Arguments:
X -- 3D points in *camera space* to transform (N, *, 3)
camera_params -- intrinsic parameters (N, 2+2+3+2=9)
"""
assert X.shape[-1] == 3
assert len(camera_params.shape) == 2
assert camera_params.shape[-1] == 9
assert X.shape[0] == camera_params.shape[0]

while len(camera_params.shape) < len(X.shape):
camera_params = camera_params.unsqueeze(1)

f = camera_params[..., :2]
c = camera_params[..., 2:4]

XX = torch.clamp(X[..., :2] / X[..., 2:], min=-1, max=1)

return f*XX + c
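
A quick smoke test of the two projection helpers, with made-up intrinsics (per the slicing above, the nine parameters are fx, fy, cx, cy, the radial coefficients k1, k2, k3, and the tangential coefficients p1, p2; the values below are placeholders, not real Human3.6M calibration):

import torch
from skelcast.data.human36mwb.camera import project_to_2d, project_to_2d_linear

cam = torch.tensor([[1145.0, 1144.0, 512.0, 515.0,  # fx, fy, cx, cy
                     -0.2, 0.24, -0.002,            # k1, k2, k3
                     -0.001, -0.0005]])             # p1, p2
X = torch.randn(1, 17, 3) * 0.5                     # one frame of 17 joints, camera space
X[..., 2] += 4.0                                    # keep joints in front of the camera
uv = project_to_2d(X, cam)                          # (1, 17, 2) pixel coordinates
uv_lin = project_to_2d_linear(X, cam)               # same, without distortion
print(uv.shape, (uv - uv_lin).abs().max())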