Skip to content

Commit

Permalink
Merge pull request #44 from kaseris/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
kaseris committed Dec 7, 2023
2 parents 58209ea + 25e45b9 commit 46fab5c
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 6 deletions.
30 changes: 24 additions & 6 deletions src/skelcast/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import logging
import pickle
from dataclasses import dataclass
from typing import Any, Tuple, List

Expand Down Expand Up @@ -172,7 +174,8 @@ def __init__(
max_number_of_bodies: int = 4,
max_duration: int = 300,
n_joints: int = 25,
transforms: Any = None
transforms: Any = None,
cache_file: str = None,
) -> None:
self.data_directory = data_directory
self.missing_files_dir = missing_files_dir
Expand All @@ -191,11 +194,21 @@ def __init__(
missing_skeleton_names=missing_files, skeleton_files=self.skeleton_files
)
self.skeleton_files_clean = []
for fname in self.skeleton_files:
if should_blacklist(fname):
continue
else:
self.skeleton_files_clean.append(fname)

if cache_file is None:
for fname in self.skeleton_files:
if should_blacklist(fname):
continue
else:
self.skeleton_files_clean.append(fname)
else:
# Check if cache file exists and then unpickle it and store its data to self.skeleton_files_clean
if os.path.exists(cache_file):
# log that we are loading the cache file
logging.info(f"Loading cache file {cache_file}...")
with open(cache_file, 'rb') as f:
self.skeleton_files_clean = pickle.load(f)


def load_labels(self):
with open(self.labels_file, 'r') as f:
Expand Down Expand Up @@ -236,3 +249,8 @@ def __getitem__(self, index) -> torch.Tensor:

def __len__(self):
return len(self.skeleton_files_clean)

def store_to_cache(self, cache_file: str) -> None:
with open(cache_file, 'wb') as f:
pickle.dump(self.skeleton_files_clean, f)
logging.info(f"Stored {len(self.skeleton_files_clean)} files to cache file {cache_file}.")
Empty file.
96 changes: 96 additions & 0 deletions src/skelcast/primitives/skeleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from enum import IntEnum


class KinectSkeleton(IntEnum):
"""
Enum that represents the Kinect's skeleton joints and their indices.
"""
SPINEBASE = 0
SPINEMID = 1
NECK = 2
HEAD = 3
SHOULDERLEFT = 4
ELBOWLEFT = 5
WRISTLEFT = 6
HANDLEFT = 7
SHOULDERRIGHT = 8
ELBOWRIGHT = 9
WRISTRIGHT = 10
HANDRIGHT = 11
HIPLEFT = 12
KNEELEFT = 13
ANKLELEFT = 14
FOOTLEFT = 15
HIPRIGHT = 16
KNEERIGHT = 17
ANKLERIGHT = 18
FOOTRIGHT = 19
SPINESHOULDER = 20
HANDTIPLEFT = 21
THUMBLEFT = 22
HANDTIPRIGHT = 23
THUMBRIGHT = 24

def connections():
"""
Returns a list of tuples that represent the connections between joints.
Connections:
---
>>> (0, 1), # SPINEBASE to SPINEMID
>>> (1, 20), # SPINEMID to SPINESHOULDER
>>> (20, 2), # SPINESHOULDER to NECK
>>> (2, 3), # NECK to HEAD
>>> (20, 4), # SPINESHOULDER to SHOULDERLEFT
>>> (4, 5), # SHOULDERLEFT to ELBOWLEFT
>>> (5, 6), # ELBOWLEFT to WRISTLEFT
>>> (6, 7), # WRISTLEFT to HANDLEFT
>>> (7, 22), # HANDLEFT to THUMBLEFT
>>> (7, 21), # HANDLEFT to HANDTIPLEFT
>>> (20, 8), # SPINESHOULDER to SHOULDERRIGHT
>>> (8, 9), # SHOULDERRIGHT to ELBOWRIGHT
>>> (9, 10), # ELBOWRIGHT to WRISTRIGHT
>>> (10, 11),# WRISTRIGHT to HANDRIGHT
>>> (11, 24),# HANDRIGHT to THUMBRIGHT
>>> (11, 23),# HANDRIGHT to HANDTIPRIGHT
>>> (0, 12), # SPINEBASE to HIPLEFT
>>> (12, 13),# HIPLEFT to KNEELEFT
>>> (13, 14),# KNEELEFT to ANKLELEFT
>>> (14, 15),# ANKLELEFT to FOOTLEFT
>>> (0, 16), # SPINEBASE to HIPRIGHT
>>> (16, 17),# HIPRIGHT to KNEERIGHT
>>> (17, 18),# KNEERIGHT to ANKLERIGHT
>>> (18, 19),# ANKLERIGHT to FOOTRIGHT
Returns:
---
- connections (list): A list of tuples that represent the connections between joints.
"""
return [
(KinectSkeleton.SPINEBASE, KinectSkeleton.SPINEMID),
(KinectSkeleton.SPINEMID, KinectSkeleton.SPINESHOULDER),
(KinectSkeleton.SPINESHOULDER, KinectSkeleton.NECK),
(KinectSkeleton.NECK, KinectSkeleton.HEAD),
(KinectSkeleton.SPINESHOULDER, KinectSkeleton.SHOULDERLEFT),
(KinectSkeleton.SHOULDERLEFT, KinectSkeleton.ELBOWLEFT),
(KinectSkeleton.ELBOWLEFT, KinectSkeleton.WRISTLEFT),
(KinectSkeleton.WRISTLEFT, KinectSkeleton.HANDLEFT),
(KinectSkeleton.HANDLEFT, KinectSkeleton.THUMBLEFT),
(KinectSkeleton.HANDLEFT, KinectSkeleton.HANDTIPLEFT),
(KinectSkeleton.SPINESHOULDER, KinectSkeleton.SHOULDERRIGHT),
(KinectSkeleton.SHOULDERRIGHT, KinectSkeleton.ELBOWRIGHT),
(KinectSkeleton.ELBOWRIGHT, KinectSkeleton.WRISTRIGHT),
(KinectSkeleton.WRISTRIGHT, KinectSkeleton.HANDRIGHT),
(KinectSkeleton.HANDRIGHT, KinectSkeleton.THUMBRIGHT),
(KinectSkeleton.HANDRIGHT, KinectSkeleton.HANDTIPRIGHT),
(KinectSkeleton.SPINEBASE, KinectSkeleton.HIPLEFT),
(KinectSkeleton.HIPLEFT, KinectSkeleton.KNEELEFT),
(KinectSkeleton.KNEELEFT, KinectSkeleton.ANKLELEFT),
(KinectSkeleton.ANKLELEFT, KinectSkeleton.FOOTLEFT),
(KinectSkeleton.SPINEBASE, KinectSkeleton.HIPRIGHT),
(KinectSkeleton.HIPRIGHT, KinectSkeleton.KNEERIGHT),
(KinectSkeleton.KNEERIGHT, KinectSkeleton.ANKLERIGHT),
(KinectSkeleton.ANKLERIGHT, KinectSkeleton.FOOTRIGHT),
]
72 changes: 72 additions & 0 deletions src/skelcast/primitives/visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import time

from enum import Enum
from typing import Union

import numpy as np
import torch
import open3d as o3d

from skelcast.primitives.skeleton import KinectSkeleton


class Colors(Enum):
"""
Enum that represents the colors used for visualizing the skeleton.
"""
RED = [1, 0, 0]
GREEN = [0, 1, 0]
BLUE = [0, 0, 1]
YELLOW = [1, 1, 0]
CYAN = [0, 1, 1]
MAGENTA = [1, 0, 1]
WHITE = [1, 1, 1]
BLACK = [0, 0, 0]



def visualize_skeleton(skeleton: Union[np.ndarray, torch.Tensor],
framerate: int = 30,
skeleton_type: str = 'kinect'):
assert isinstance(skeleton, (np.ndarray, torch.Tensor)), f'Expected a numpy array or a PyTorch tensor, got {type(skeleton)} instead.'
# We assume that the skeleton movement has a shape of (seq_len, n_joints, 3)
if isinstance(skeleton, torch.Tensor):
skeleton = skeleton.to(torch.float64).numpy()
assert len(skeleton.shape) == 3, f'Expected a 3-dimensional array, got {len(skeleton.shape)} dimensions instead.'
assert skeleton.shape[2] == 3, f'Expected the last dimension to be 3, got {skeleton.shape[2]} instead.'
seq_len, n_joints, _ = skeleton.shape
if skeleton_type == 'kinect':
assert n_joints == 25, f'Expected the second dimension to be 25, got {n_joints} instead.'
connections = KinectSkeleton.connections()

# Create a point cloud object and a line set object
# These serve as containers for the skeleton data and the connections between joints
point_cloud = o3d.geometry.PointCloud()
line_set = o3d.geometry.LineSet()

# Create a visualization window
vis = o3d.visualization.Visualizer()
vis.create_window()

for timestep in range(seq_len):
# Update point cloud for the current timestep
point_cloud.points = o3d.utility.Vector3dVector(skeleton[timestep])
point_cloud.paint_uniform_color(Colors.RED.value) # Red color for joints

bone_lines = [[i.value, j.value] for i, j in connections]
line_set.lines = o3d.utility.Vector2iVector(bone_lines)
line_set.points = o3d.utility.Vector3dVector(skeleton[timestep])
line_set.colors = o3d.utility.Vector3dVector([Colors.BLUE.value for _ in connections]) # Blue color for connections

if timestep == 0:
vis.add_geometry(point_cloud)
vis.add_geometry(line_set)
else:
vis.update_geometry(point_cloud)
vis.update_geometry(line_set)

vis.poll_events()
vis.update_renderer()
time.sleep(1.0 / framerate)

vis.destroy_window()
25 changes: 25 additions & 0 deletions tools/visualize_skel_movement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import argparse
import logging

from skelcast.data.dataset import NTURGBDDataset
from skelcast.primitives.visualize import visualize_skeleton

argparser = argparse.ArgumentParser(description='Visualize skeleton movement.')
argparser.add_argument('--dataset', type=str, required=True, help='Path to the dataset.')
argparser.add_argument('--sample', type=int, required=True, help='Sample index to visualize.')
argparser.add_argument('--cache-file', type=str, required=False, help='Path to the cache file.')

args = argparser.parse_args()


if __name__ == '__main__':
log_format = '[%(asctime)s] %(levelname)s: %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)

dataset = NTURGBDDataset(args.dataset, missing_files_dir='data/missing/', label_file='data/labels.txt',
cache_file=args.cache_file,
max_number_of_bodies=1)
skeleton, label = dataset[args.sample]
logging.info(f'Label: {label}')
visualize_skeleton(skeleton.squeeze(1))

0 comments on commit 46fab5c

Please sign in to comment.