Skip to content

Commit

Permalink
Merge pull request #31 from wearable-motion-capture/refactor_estimators
Browse files Browse the repository at this point in the history
Refactor estimators
  • Loading branch information
faweigend authored Jun 11, 2024
2 parents 1725c1a + d8ab6fa commit c02bf0a
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 45 deletions.
58 changes: 21 additions & 37 deletions experimental_applications/watch_phone_pocket_lstm.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,42 @@
import argparse
import atexit
import logging
import queue
import signal
import threading

from wear_mocap_ape import config
from wear_mocap_ape.data_deploy.nn import deploy_models
from wear_mocap_ape.data_types import messaging
from wear_mocap_ape.stream.listener.imu import ImuListener
from wear_mocap_ape.stream.publisher.watch_phone_pocket_nn_udp import WatchPhonePocketNnUDP
from wear_mocap_ape.estimate.watch_phone_pocket_nn import WatchPhonePocketNN
from wear_mocap_ape.stream.publisher.imu_udp import IMUPublisherUDP


def run_watch_phone_pocket_nn_udp(ip: str, smooth: int, stream_mc: bool) -> WatchPhonePocketNnUDP:
# data for left-hand mode
q = queue.Queue()

def run_watch_phone_pocket_nn_udp(ip: str, smooth: int) -> WatchPhonePocketNN:
# listen for imu data from phone and watch
imu_listener = ImuListener(
ip=ip,
msg_size=messaging.watch_phone_imu_msg_len,
port=config.PORT_LISTEN_WATCH_PHONE_IMU
)
imu_thread = threading.Thread(
target=imu_listener.listen,
args=(q,)
)
sensor_q = imu_listener.listen_in_thread()

# process into arm pose and body orientation
estimator = WatchPhonePocketNnUDP(ip=ip,
port=config.PORT_PUB_LEFT_ARM,
smooth=smooth,
model_hash=deploy_models.LSTM.WATCH_PHONE_POCKET.value,
stream_mc=stream_mc,
mc_samples=60)
udp_thread = threading.Thread(
target=estimator.processing_loop,
args=(q,)
estimator = WatchPhonePocketNN(smooth=smooth,
model_hash=deploy_models.LSTM.WATCH_PHONE_POCKET.value,
add_mc_samples=True,
monte_carlo_samples=60)
msg_q = estimator.process_in_thread(sensor_q)

# the publisher publishes pose estimates from the queue via UDP
pub = IMUPublisherUDP(
ip=ip,
port=config.PORT_PUB_LEFT_ARM
)
pub.publish_in_thread(msg_q)

imu_thread.start()
udp_thread.start()

def terminate_all(*args):
imu_listener.terminate()
estimator.terminate()

# make sure all handler exit on termination
atexit.register(terminate_all)
signal.signal(signal.SIGTERM, terminate_all)
signal.signal(signal.SIGINT, terminate_all)
# wait for any key to end the threads
input("[TERMINATION TRIGGER] press enter to exit")
imu_listener.terminate()
estimator.terminate()
pub.terminate()

return estimator

Expand All @@ -63,15 +50,12 @@ def terminate_all(*args):
# Required IP argument
parser.add_argument('ip', type=str, help=f'put your local IP here.')
parser.add_argument('smooth', nargs='?', type=int, default=5, help=f'smooth predicted trajectories')
parser.add_argument('--stream_mc', action='store_true')
parser.add_argument('--no-stream_mc', dest='stream_mc', action='store_false')
parser.set_defaults(stream_mc=True)

args = parser.parse_args()

ip_arg = args.ip
smooth_arg = args.smooth
stream_mc_arg = args.stream_mc

# run the predictions
run_watch_phone_pocket_nn_udp(ip_arg, smooth_arg, stream_mc_arg)
run_watch_phone_pocket_nn_udp(ip_arg, smooth_arg)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
"model": "DropoutLSTM",
"hidden_layer_count": 2,
"hidden_layer_size": 256,
"epochs": 200,
"batch_size": 128,
"learning_rate": 0.0015,
"dropout": 0.2,
"sequence_len": 6,
"normalize": true,
"seq_overlap": true,
"create_norm_stats": true,
"early_stopping": 10,
"hash": "670b66fa7664252d1cfb3b5a8a362002ffeeba5c",
"y_targets_n": "ORI_CAL_LARM_UARM_HIPS",
"x_inputs_n": "WATCH_PHONE_CAL_HIP",
"y_targets_v": [
"gt_larm_6drr_cal_1",
"gt_larm_6drr_cal_2",
"gt_larm_6drr_cal_3",
"gt_larm_6drr_cal_4",
"gt_larm_6drr_cal_5",
"gt_larm_6drr_cal_6",
"gt_uarm_6drr_cal_1",
"gt_uarm_6drr_cal_2",
"gt_uarm_6drr_cal_3",
"gt_uarm_6drr_cal_4",
"gt_uarm_6drr_cal_5",
"gt_uarm_6drr_cal_6",
"gt_hips_yrot_cal_sin",
"gt_hips_yrot_cal_cos"
],
"x_inputs_v": [
"sw_dt",
"sw_gyro_x",
"sw_gyro_y",
"sw_gyro_z",
"sw_lvel_x",
"sw_lvel_y",
"sw_lvel_z",
"sw_lacc_x",
"sw_lacc_y",
"sw_lacc_z",
"sw_grav_x",
"sw_grav_y",
"sw_grav_z",
"sw_6drr_cal_1",
"sw_6drr_cal_2",
"sw_6drr_cal_3",
"sw_6drr_cal_4",
"sw_6drr_cal_5",
"sw_6drr_cal_6",
"sw_pres_cal",
"ph_hips_yrot_cal_sin",
"ph_hips_yrot_cal_cos"
],
"datetime": "2024-06-11 11:18:45",
"Loss/train": 0.1802734130402406,
"Loss/test": 0.21087305423568684,
"Loss/b_test": 0.20764025849280615,
"MAE/Hand": 9.766056275381166,
"MAE/Elbow": 8.242187291172527,
"RMSE/Hand": 11.265157274595596,
"RMSE/Elbow": 9.843276029945384
}
2 changes: 1 addition & 1 deletion src/wear_mocap_ape/data_deploy/nn/deploy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


class LSTM(Enum):
WATCH_PHONE_POCKET = "7ffa47dfbf2e1a16057e0253a1c2ba01e465d55a"
WATCH_PHONE_POCKET = "670b66fa7664252d1cfb3b5a8a362002ffeeba5c"
WATCH_PHONE_UARM = "7cb5cdf94ef4c66388c7f15f642005d5e008146a"
WATCH_ONLY = "04f4ad63bfccb3668f7598c9375403e10b1fae2a"
12 changes: 5 additions & 7 deletions src/wear_mocap_ape/estimate/watch_phone_pocket_nn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from abc import abstractmethod

import numpy as np
import torch

Expand Down Expand Up @@ -27,11 +25,11 @@ def __init__(self,
self.__slp = messaging.WATCH_PHONE_IMU_LOOKUP

# load model from given parameters
self.__nn_model, params = models.load_deployed_model_from_hash(hash_str=model_hash)
self.__nn_model, params = nn_models.load_deployed_model_from_hash(hash_str=model_hash)

super().__init__(
x_inputs=NNS_INPUTS(params["x_inputs_v"]),
y_targets=NNS_TARGETS(params["y_targets_v"]),
x_inputs=NNS_INPUTS[params["x_inputs_n"]],
y_targets=NNS_TARGETS[params["y_targets_n"]],
smooth=smooth,
normalize=params["normalize"],
seq_len=params["sequence_len"],
Expand Down Expand Up @@ -99,7 +97,7 @@ def parse_row_to_xx(self, row):

def make_prediction_from_row_hist(self, xx):
# cast to a torch tensor with batch size 1
xx = torch.tensor(xx[None, :, :])
xx = torch.tensor(xx[None, :, :], dtype=torch.float32)
with torch.no_grad():
# make mote carlo predictions if the model makes use of dropout
t_preds = self.__nn_model.monte_carlo_predictions(x=xx, n_samples=self.__mc_samples)
Expand All @@ -111,4 +109,4 @@ def make_prediction_from_row_hist(self, xx):

# we are only interested in the last prediction of the sequence
t_preds = t_preds[:, -1, :]
return t_preds
return t_preds

0 comments on commit c02bf0a

Please sign in to comment.