From ea9002c4546053fed5bdd2b086114c9e28003708 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Tue, 16 Jan 2024 16:25:58 +0200
Subject: [PATCH 1/3] Amass dataset torch interface

---
 src/skelcast/data/dataset.py | 101 ++++++++++++++++++++++++++++++++-
 1 file changed, 100 insertions(+), 1 deletion(-)

diff --git a/src/skelcast/data/dataset.py b/src/skelcast/data/dataset.py
index 8a03201..6534cc9 100644
--- a/src/skelcast/data/dataset.py
+++ b/src/skelcast/data/dataset.py
@@ -354,4 +354,103 @@ def __len__(self):
         if self.train:
             return len(self.train_inputs)
         else:
-            return len(self.test_inputs)
\ No newline at end of file
+            return len(self.test_inputs)
+
+
+def _identity(x):
+    """Return ``x`` unchanged; picklable stand-in for ``lambda x: x``."""
+    return x
+
+
+@DATASETS.register_module()
+class AMASSTorchDataset(Dataset):
+    """Sliding-window view over concatenated AMASS sequences.
+
+    Each row ``(start, end)`` of ``index`` delimits one recording inside
+    ``pose_data`` (and ``trans_data`` when given). Every sample is a
+    contiguous slice of ``history_length + prediction_horizon`` frames that
+    is split into a history part and a prediction target.
+
+    Args:
+        index: 2-D integer array of ``(start, end)`` frame offsets, one row
+            per recording; ``end`` is exclusive.
+        pose_data: Frame-indexed pose array (first axis is the frame).
+        history_length: Number of observed frames per sample.
+        prediction_horizon: Number of frames to predict per sample.
+        trans_data: Optional frame-indexed array sliced like ``pose_data``.
+        transform: Callable applied to every raw pose slice.
+        window: Optional ``(length, stride)`` pair. When set, sample starts
+            are ``stride`` frames apart and ``length`` frames must fit in
+            the recording; otherwise every valid start is used.
+        **kwargs: Accepted for config compatibility and ignored.
+    """
+
+    def __init__(self,
+                 index,
+                 pose_data,
+                 history_length: int,
+                 prediction_horizon: int,
+                 trans_data=None,
+                 transform=_identity,
+                 window=None,
+                 **kwargs):
+        self.history_length = history_length
+        self.prediction_horizon = prediction_horizon
+        self.transform = transform
+
+        self._index = index
+        self._pose_data = pose_data
+        self._trans_data = trans_data
+        self._window = window
+
+        # One (sequence, offset) pair per sample, filled by index_data().
+        self.dataset_index = []
+        self.index_data()
+
+    def index_data(self):
+        """Populate ``dataset_index`` with every valid sample start."""
+        sample_length = self.history_length + self.prediction_horizon
+        if self._window:
+            # A window shorter than one sample must not admit starts whose
+            # slice would run past the end of the recording.
+            span = max(int(self._window[0]), sample_length)
+            stride = int(self._window[1])
+        else:
+            span, stride = sample_length, 1
+
+        seq_lengths = self._index[:, 1] - self._index[:, 0]
+        for j, seq_length in enumerate(seq_lengths):
+            # "+ 1" keeps the last start at which a full span still fits.
+            last_start = int(seq_length) - span + 1
+            self.dataset_index.extend(
+                (j, i) for i in range(0, last_start, stride))
+
+    def __getitem__(self, item):
+        """Return the ``(history, future)`` tensors of sample ``item``.
+
+        The tuple is ``(pose_history, pose_future)``; when ``trans_data``
+        was given it is extended with ``(trans_history, trans_future)``.
+        """
+        i_idx, i = self.dataset_index[item]
+        start = self._index[i_idx, 0] + i
+        stop = start + self.history_length + self.prediction_horizon
+        split = self.history_length
+        end = split + self.prediction_horizon
+
+        ts_pose = torch.tensor(self.transform(self._pose_data[start:stop]))
+        if self._trans_data is None:
+            return ts_pose[:split], ts_pose[split:end]
+
+        ts_trans = torch.tensor(self._trans_data[start:stop])
+        return (ts_pose[:split], ts_pose[split:end],
+                ts_trans[:split], ts_trans[split:end])
+
+    def __len__(self):
+        return len(self.dataset_index)
+
+    def hparam(self) -> dict:
+        """Return the data hyper-parameters under ``DATA_``-prefixed keys."""
+        return {
+            'DATA_history_length': self.history_length,
+            'DATA_prediction_horizon': self.prediction_horizon,
+        }

From fc0ee2b2405f3b7567d6ac4eef514dd5e407d1db Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Tue, 16 Jan 2024 16:27:55 +0200
Subject: [PATCH 2/3] Quaternion data model

---
 src/skelcast/utils/quaternion.py | 363 +++++++++++++++++++++++++++++++
 1 file changed, 363 insertions(+)
 create mode 100644 src/skelcast/utils/quaternion.py

diff --git a/src/skelcast/utils/quaternion.py b/src/skelcast/utils/quaternion.py
new file mode 100644
index 0000000..daaf678
--- /dev/null
+++ b/src/skelcast/utils/quaternion.py
@@ -0,0 +1,363 @@
+import torch
+from math import pi
+
+
+class Quaternion(object):
+    """Batched quaternion backed by a torch tensor of shape (..., 4).
+
+    Components are stored scalar first, i.e. ``(w, x, y, z)``. No method
+    normalises implicitly; use ``normalized`` where a unit quaternion is
+    required.
+
+    Construct from one positional tensor or Quaternion, or from the keyword
+    pairs ``axis``/``angle`` or ``axis``/``rodriguez_parameter``.
+    """
+
+    def __init__(self, *args, **kwargs):
+        if not args:
+            if ("axis" in kwargs) and ("angle" in kwargs):
+                self._q = Quaternion._from_axis_angle(kwargs["axis"], kwargs["angle"]).q
+            elif ("axis" in kwargs) and ("rodriguez_parameter" in kwargs):
+                self._q = Quaternion._from_rodrigues_vector(
+                    kwargs["axis"], kwargs["rodriguez_parameter"]).q
+            else:
+                # Fail here rather than with an AttributeError on first use.
+                raise ValueError(
+                    "Quaternion needs a tensor, or axis= with angle= or rodriguez_parameter=")
+        else:
+            q = args[0]
+            if Quaternion.is_quaternion(q):
+                self._q = q._q
+            else:
+                assert q.shape[-1] == 4, 'Quaternion has to be of dimension 4'
+                self._q = q
+
+    @classmethod
+    def _from_axis_angle(cls, axis, angle):
+        """Build a quaternion from a rotation axis and a rotation angle.
+
+        Params:
+            axis: tensor of shape (..., 3). It is normalised here; an axis of
+                (near) zero length yields the identity rotation.
+            angle: tensor of shape (...) with the rotation angle in radians.
+        """
+        norm = axis.square().sum(-1).sqrt().unsqueeze(-1)
+        axis = axis / norm.clamp_min(1e-12)
+        theta = angle.unsqueeze(-1) / 2.0
+        r = torch.where(norm > 1e-12, torch.cos(theta), torch.ones_like(theta))
+        i = torch.where(norm > 1e-12, axis * torch.sin(theta), torch.zeros_like(axis))
+        q = torch.cat([r, i], dim=-1)
+        return cls(q)
+
+    @classmethod
+    def _from_rodrigues_vector(cls, axis, rodrigues_parameter):
+        """Build a quaternion from an axis and ``tan(angle / 2)``."""
+        norm = axis.square().sum(-1).sqrt().unsqueeze(-1)
+        axis = axis / norm.clamp_min(1e-12)
+        theta = torch.atan(rodrigues_parameter).unsqueeze(-1)
+        r = torch.where(norm > 1e-12, torch.cos(theta), torch.ones_like(theta))
+        i = torch.where(norm > 1e-12, axis * torch.sin(theta), torch.zeros_like(axis))
+        q = torch.cat([r, i], dim=-1)
+        return cls(q)
+
+    def __mul__(self, other):
+        """Hamilton product ``self * other``; ``other`` may be a raw tensor."""
+        if Quaternion.is_quaternion(other):
+            return self.__class__(torch.matmul(self._q_matrix(), other._q.unsqueeze(-1)).squeeze(-1))
+        return self * self.__class__(other)
+
+    @classmethod
+    def mul_(cls, q1, q2):
+        return (cls(q1) * cls(q2)).q
+
+    def __repr__(self):
+        return f"Quaternion: {self._q.__repr__()}"
+
+    def _q_matrix(self):
+        """Left-multiplication matrix ``L`` such that ``L @ p`` is ``self * p``."""
+        return torch.stack([
+            torch.stack([self._q[..., 0], -self._q[..., 1], -self._q[..., 2], -self._q[..., 3]], dim=-1),
+            torch.stack([self._q[..., 1], self._q[..., 0], -self._q[..., 3], self._q[..., 2]], dim=-1),
+            torch.stack([self._q[..., 2], self._q[..., 3], self._q[..., 0], -self._q[..., 1]], dim=-1),
+            torch.stack([self._q[..., 3], -self._q[..., 2], self._q[..., 1], self._q[..., 0]], dim=-1)], dim=-2)
+
+    def _q_bar_matrix(self):
+        """Right-multiplication matrix ``R`` such that ``R @ p`` is ``p * self``."""
+        return torch.stack([
+            torch.stack([self._q[..., 0], -self._q[..., 1], -self._q[..., 2], -self._q[..., 3]], dim=-1),
+            torch.stack([self._q[..., 1], self._q[..., 0], self._q[..., 3], -self._q[..., 2]], dim=-1),
+            torch.stack([self._q[..., 2], -self._q[..., 3], self._q[..., 0], self._q[..., 1]], dim=-1),
+            torch.stack([self._q[..., 3], self._q[..., 2], -self._q[..., 1], self._q[..., 0]], dim=-1)], dim=-2)
+
+    def _rotate_quaternion(self, q):
+        """Rotate a quaternion vector using the stored rotation.
+
+        Params:
+            q: The vector to be rotated, in quaternion form (0 + xi + yj + zk)
+
+        Returns:
+            A Quaternion object holding the rotated vector in quaternion form (0 + xi + yj + zk)
+        """
+        return self.__class__(self * q * self.conjugate)
+
+    @classmethod
+    def rotate_(cls, q, v):
+        return cls(q).rotate(v)
+
+    def rotate(self, vector):
+        """Rotate 3D vectors by the rotation stored in the Quaternion object.
+
+        Params:
+            vector: tensor of shape (..., 3), or a Quaternion whose scalar part
+                is 0 and whose vector part is the vector to rotate.
+
+        Returns:
+            The rotated vectors as a tensor of shape (..., 3), or a Quaternion
+            when a Quaternion was passed in.
+        """
+        if Quaternion.is_quaternion(vector):
+            return self._rotate_quaternion(vector)
+        q = Quaternion(torch.cat([torch.zeros_like(vector[..., [0]]), vector], dim=-1))
+        a = self._rotate_quaternion(q).vector
+        return a
+
+    @classmethod
+    def conjugate_(cls, q):
+        return cls(q).conjugate.q
+
+    @property
+    def conjugate(self):
+        """Quaternion conjugate, encapsulated in a new instance.
+        For a unit quaternion, this is the same as the inverse.
+        Returns:
+            A new Quaternion object clone with its vector part negated
+        """
+        return self.__class__(torch.cat([self.scalar.unsqueeze(-1), -self.vector], dim=-1))
+
+    def _normalize(self):
+        """Object is guaranteed to be a unit quaternion after calling this
+        operation UNLESS the object is equivalent to Quaternion(0)
+        """
+        self._q = torch.nn.functional.normalize(self._q, dim=-1)
+
+    @property
+    def scalar(self):
+        """Return the real or scalar component of the quaternion object.
+
+        Returns:
+            A tensor of shape (...)
+        """
+        return self._q[..., 0]
+
+    @property
+    def vector(self):
+        """Return the imaginary or vector component of the quaternion object.
+
+        Returns:
+            A tensor of shape (..., 3). NOT guaranteed to be a unit vector
+        """
+        return self._q[..., 1:]
+
+    @classmethod
+    def rotation_matrix_(cls, q):
+        return cls(q).rotation_matrix
+
+    @property
+    def rotation_matrix(self):
+        """Get the 3x3 rotation matrix equivalent of the quaternion rotation.
+
+        Returns:
+            A tensor of shape (..., 3, 3)
+
+        Note:
+            This only is a rotation matrix for a unit quaternion. Nothing is
+            normalised here; use ``normalized`` first when in doubt.
+        """
+        product_matrix = torch.matmul(self._q_matrix(), self._q_bar_matrix().conj().transpose(-2, -1))
+        return product_matrix[..., 1:, 1:]
+
+    @property
+    def normalized(self):
+        """Get a unit quaternion (versor) copy of this Quaternion object.
+
+        A unit quaternion has a `norm` of 1.0
+
+        Returns:
+            A new Quaternion object clone that is guaranteed to be a unit quaternion
+        """
+        q = Quaternion(self._q)
+        q._normalize()
+        return q
+
+    @property
+    def q(self):
+        return self._q
+
+    @classmethod
+    def euler_angle_(cls, q, order, epsilon=0.):
+        return cls(q).euler_angle(order, epsilon)
+
+    def euler_angle(self, order: str = 'zyx', epsilon: float = 0.):
+        """
+        Convert quaternion(s) q to Euler angles.
+
+        Params:
+            order: one of 'xyz', 'yzx', 'zxy', 'xzy', 'yxz', 'zyx'.
+            epsilon: margin by which the asin argument is kept inside [-1, 1].
+
+        Returns:
+            A tensor of shape (..., 3) holding the (x, y, z) angles in radians.
+
+        Raises:
+            ValueError: if `order` is not one of the supported orders.
+        """
+        assert self.q.shape[-1] == 4
+
+        original_shape = list(self.q.shape)
+        original_shape[-1] = 3
+        # reshape (unlike view) also accepts non-contiguous tensors
+        q = self.q.reshape(-1, 4)
+
+        q0 = q[:, 0]
+        q1 = q[:, 1]
+        q2 = q[:, 2]
+        q3 = q[:, 3]
+
+        if order == 'xyz':
+            x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
+            y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
+            z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
+        elif order == 'yzx':
+            x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
+            y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
+            z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
+        elif order == 'zxy':
+            x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
+            y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
+            z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
+        elif order == 'xzy':
+            x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
+            y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
+            z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
+        elif order == 'yxz':
+            x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
+            y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
+            z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
+        elif order == 'zyx':
+            x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
+            y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
+            z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
+        else:
+            # a bare `raise` has no active exception to re-raise here
+            raise ValueError(f"Unsupported Euler order: {order!r}")
+
+        return torch.stack((x, y, z), dim=1).view(original_shape)
+
+    def get_axis(self):
+        """Get the axis or vector about which the quaternion rotation occurs
+        For a null rotation (a purely real quaternion), the rotation angle will
+        always be `0`, but the rotation axis is undefined.
+        It is returned as `[0, 0, 0]` in that case.
+        Returns:
+            A tensor of shape (..., 3) with the unit axis of rotation.
+        """
+        tolerance = 1e-12
+        norm = torch.norm(self.vector, dim=-1).unsqueeze(-1)
+        return torch.where(norm > tolerance, torch.nn.functional.normalize(self.vector, dim=-1), torch.zeros_like(self.vector))
+
+    @property
+    def axis(self):
+        return self.get_axis()
+
+    def _wrap_angle(self, theta):
+        """Helper method: Wrap any angle to lie between -pi and pi
+        Odd multiples of pi are wrapped to +pi (as opposed to -pi)
+        """
+        # remainder(theta + pi, 2 * pi) - pi would map odd multiples of pi to
+        # -pi; mirroring the expression gives the documented (-pi, pi] range.
+        return pi - torch.remainder(pi - theta, 2 * pi)
+
+    @property
+    def angle(self):
+        """Get the angle (in radians) describing the magnitude of the quaternion rotation about its rotation axis.
+        The result lies in the range (-pi, pi] with the direction of
+        rotation indicated by the sign.
+        When a particular rotation describes a 180 degree rotation about an arbitrary
+        axis vector `v`, the conversion to axis / angle representation may jump
+        discontinuously between all permutations of `(-pi, pi)` and `(-v, v)`,
+        each being geometrically equivalent.
+        Returns:
+            A tensor of shape (...) with the angle of rotation in radians.
+        Note:
+            This only makes sense for a unit quaternion; nothing is normalised here.
+        """
+        norm = torch.norm(self.vector, dim=-1)
+        return self._wrap_angle(2.0 * torch.atan2(norm, self.scalar))
+
+    @property
+    def axis_angle(self):
+        return self.angle.unsqueeze(-1) * self.axis
+
+    @property
+    def rodriguez_vector(self):
+        return torch.tan(self.angle.unsqueeze(-1) / 2.) * self.axis
+
+    @staticmethod
+    def is_quaternion(other):
+        return 'Quaternion' in other.__class__.__name__
+
+    @staticmethod
+    def qfix_(q: torch.Tensor) -> torch.Tensor:
+        """
+        Enforce quaternion continuity across the time dimension by selecting
+        the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
+        between two consecutive frames.
+
+        Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
+        Returns a tensor of the same shape.
+        """
+        assert len(q.shape) == 3
+        assert q.shape[-1] == 4
+
+        result = q.clone()
+        dot_products = torch.sum(q[1:] * q[:-1], dim=2)
+        mask = dot_products < 0
+        mask = (torch.cumsum(mask, dim=0) % 2).type(torch.bool)
+        result[1:][mask] *= -1
+        return result
+
+    @staticmethod
+    def qfix_positive_(q: torch.Tensor) -> torch.Tensor:
+        """
+        Enforce quaternion w to be positive
+
+        Expects a tensor of shape (..., 4).
+        Returns a tensor of the same shape.
+        """
+        assert q.shape[-1] == 4
+
+        mask = q[..., 0] < 0.
+        q_out = q.clone()
+        q_out[mask] *= -1.
+        return q_out
+
+    @staticmethod
+    def weighted_mean_(q: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        """
+        Computes weighted mean of multiple quaternions.
+
+        Expects a tensor of shape (*, N, 4), where N is the number of quaternions and a weight tensor of the same
+        or broadcastable shape. Weight have to sum up to 1.
+
+        Returns a tensor of shape (*, 4).
+        """
+
+        # NOTE(review): Q.T @ Q weights each term by w**2, not w - confirm intent.
+        Q = (q * w)
+        QQT = Q.transpose(-1, -2) @ Q
+        # torch.symeig was removed in PyTorch 2.0. torch.linalg.eigh likewise
+        # returns eigenvalues in ascending order, so the last eigenvector
+        # column belongs to the largest eigenvalue.
+        mean_q_unnorm = torch.linalg.eigh(QQT)[1][..., -1]
+        mean_q = torch.nn.functional.normalize(mean_q_unnorm, dim=-1)
+        return mean_q

From 1fde5b42f76f64b0c03f1bfdf6c62c2985e98a8f Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Tue, 16 Jan 2024 16:43:12 +0200
Subject: [PATCH 3/3] Preprocessing script for amass dataset

---
 tools/preprocess/preprocess_amass.py | 115 +++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 tools/preprocess/preprocess_amass.py

diff --git a/tools/preprocess/preprocess_amass.py b/tools/preprocess/preprocess_amass.py
new file mode 100644
index 0000000..60641c1
--- /dev/null
+++ b/tools/preprocess/preprocess_amass.py
@@ -0,0 +1,115 @@
+import argparse
+import logging
+import os
+import tarfile
+from io import BytesIO
+
+import numpy as np
+import torch
+import zarr
+
+from skelcast.utils.quaternion import Quaternion
+
+ARCHIVE_SUFFIX = '.tar.bz2'
+
+
+def process_data(dataset_dir, out, target_fps=20):
+    """Convert one AMASS tar archive into zarr arrays stored under ``out``.
+
+    Every ``.npz`` member is subsampled towards ``target_fps``. The first 66
+    pose values of each frame are read as 22 axis-angle joints, converted to
+    sign-continuous quaternions and appended to ``poses.zarr``; ``trans`` is
+    appended to ``trans.zarr`` and the ``[start, end)`` frame range of the
+    recording to ``poses_index.zarr``. Members that fail to load or convert
+    are logged and skipped.
+
+    Params:
+        dataset_dir: path of the tar archive to read.
+        out: output directory, created when missing.
+        target_fps: frame rate the recordings are subsampled to.
+    """
+    os.makedirs(out, exist_ok=True)
+    z_poses = zarr.open(os.path.join(out, 'poses.zarr'), mode='w', shape=(0, 22, 4), chunks=(1000, 22, 4), dtype=np.float32)
+    z_trans = zarr.open(os.path.join(out, 'trans.zarr'), mode='w', shape=(0, 3), chunks=(1000, 3), dtype=np.float32)
+    z_index = zarr.open(os.path.join(out, 'poses_index.zarr'), mode='w', shape=(0, 2), chunks=(1000, 2), dtype=int)
+    i = 0
+    # The context manager closes the archive even when a member fails.
+    with tarfile.open(dataset_dir, 'r') as tar:
+        for member in tar:
+            file_name = os.path.basename(member.name)
+            if not file_name.endswith('.npz') or file_name.startswith('.'):
+                continue
+            try:
+                with tar.extractfile(member) as f:
+                    # np.load needs a seekable file, so buffer the member.
+                    array_file = BytesIO(f.read())
+                with np.load(array_file) as data:
+                    frame_rate = data['mocap_framerate']
+
+                    if not frame_rate % target_fps == 0.:
+                        logging.warning(f'FPS does not match for dataset: {dataset_dir}')
+                    # max(1, ...) avoids a zero slice step for recordings
+                    # slower than half the target rate.
+                    frame_multiplier = max(1, int(np.round(frame_rate / target_fps)))
+
+                    body_pose = data['poses'][::frame_multiplier, 0:66].reshape(-1, 22, 3)
+                    body_trans = data['trans'][::frame_multiplier]
+
+                t_body_pose = torch.tensor(body_pose)
+                q_body_pose = Quaternion.qfix_(Quaternion(angle=t_body_pose.norm(dim=-1), axis=t_body_pose).q).numpy()
+
+                z_poses.append(q_body_pose, axis=0)
+                z_trans.append(body_trans, axis=0)
+                z_index.append(np.array([[i, i + q_body_pose.shape[0]]]), axis=0)
+                i = i + q_body_pose.shape[0]
+            except Exception:
+                # Best effort: one broken recording must not abort the run.
+                logging.exception('Skipping %s', member.name)
+
+
+def parse_args(argv=None):
+    """Parse the command line (``sys.argv`` when ``argv`` is None)."""
+    parser = argparse.ArgumentParser(description='AMASS Process Raw Data')
+
+    parser.add_argument('path',
+                        type=str,
+                        help='Path of the tar files')
+
+    parser.add_argument('out',
+                        type=str,
+                        help='The output path')
+
+    # nargs='?' is required for the default of a positional to take effect.
+    parser.add_argument('fps',
+                        type=int,
+                        nargs='?',
+                        default=20,
+                        help='FPS')
+
+    parser.add_argument('--datasets',
+                        type=str,
+                        nargs="+",
+                        help='The names of the datasets to process',
+                        default=None)
+
+    return parser.parse_args(argv)
+
+
+def main(argv=None):
+    """Process the selected archives, or every archive found in ``path``."""
+    args = parse_args(argv)
+    datasets = args.datasets
+    if datasets is None:
+        # Iterating over None used to raise a TypeError.
+        datasets = sorted(name[:-len(ARCHIVE_SUFFIX)]
+                          for name in os.listdir(args.path)
+                          if name.endswith(ARCHIVE_SUFFIX))
+    for dataset in datasets:
+        print(f"Processing {dataset}...")
+        process_data(os.path.join(args.path, dataset + ARCHIVE_SUFFIX),
+                     os.path.join(args.out, dataset),
+                     target_fps=args.fps)
+
+
+if __name__ == '__main__':
+    main()