move kalman model to APE library
faweigend committed Oct 9, 2023
1 parent 62ec260 commit 16c8437
Showing 6 changed files with 763 additions and 14 deletions.
18 changes: 9 additions & 9 deletions example_scripts/watch_phone_pocket_stream.py
@@ -6,10 +6,9 @@
 import threading
 
 from wear_mocap_ape import config
-from wear_mocap_ape.data_deploy.nn import deploy_models
 from wear_mocap_ape.data_types import messaging
-from wear_mocap_ape.estimate.phone_pocket_free_hips_udp import FreeHipsPocketUDP
 from wear_mocap_ape.stream.listener.imu import ImuListener
+from wear_mocap_ape.stream.publisher.kalman_pocket_phone_udp import KalmanPhonePocket
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
@@ -42,15 +41,16 @@
     )
 
     # process into arm pose and body orientation
-    fhp = FreeHipsPocketUDP(ip=ip_arg,
-                            model_hash=deploy_models.LSTM.WATCH_PHONE_POCKET.value,
+    kpp = KalmanPhonePocket(ip=ip_arg,
                             smooth=smooth_arg,
+                            num_ensemble=48,
                             port=config.PORT_PUB_LEFT_ARM,
-                            monte_carlo_samples=25,
-                            stream_monte_carlo=stream_mc_arg)
+                            window_size=10,
+                            stream_mc=stream_mc_arg,
+                            model_name="SW-model-sept-4")
     p_thread = threading.Thread(
-        target=fhp.stream_loop,
-        args=(left_q,)
+        target=kpp.stream_wearable_devices,
+        args=(left_q, True,)
     )
 
     l_thread.start()
@@ -59,7 +59,7 @@

     def terminate_all(*args):
         imu_l.terminate()
-        fhp.terminate()
+        kpp.terminate()
 
 
     # make sure all handler exit on termination
4 changes: 3 additions & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = wear_mocap_ape
-version = 1.0.2
+version = 1.1.0
 author = Fabian Weigend
 author_email = fweigend@asu.edu
 description =
Expand All @@ -25,6 +25,8 @@ install_requires =
     pandas
     matplotlib
     pynput
+    einops
+    bayesian_torch
 
 packages = find:
 include_package_data = True
8 changes: 4 additions & 4 deletions src/wear_mocap_ape/config.py
@@ -1,10 +1,10 @@
-import os
+from pathlib import Path
 
-proj_path = os.path.dirname(os.path.abspath(__file__))
+proj_path = Path(__file__).parent.absolute()
 
 PATHS = {
-    "deploy": f"{proj_path}/data_deploy/",
-    "skeleton": f"{proj_path}/data_deploy/"
+    "deploy": proj_path / "data_deploy",
+    "skeleton": proj_path / "data_deploy"
 }
 
 # ports for publishing to other services
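Note that the values in PATHS are now pathlib.Path objects rather than strings. A minimal sketch of what this means for downstream code (the "nn" subfolder and the file name are illustrative assumptions, not part of this commit):

    from wear_mocap_ape import config

    # Path objects support the / join operator ...
    model_dir = config.PATHS["deploy"] / "nn"
    # ... and interpolate cleanly into f-strings, so callers that assemble
    # path strings keep working (minus the old trailing slash)
    model_file = f"{config.PATHS['deploy']}/nn/model.pt"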
Binary file not shown.
215 changes: 215 additions & 0 deletions src/wear_mocap_ape/estimate/kalman_models.py
@@ -0,0 +1,215 @@
from bayesian_torch.layers.flipout_layers.linear_flipout import LinearFlipout
from torch.distributions.multivariate_normal import MultivariateNormal
from einops import rearrange, repeat
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class utils:
def __init__(self, num_ensemble, dim_x, dim_z):
self.num_ensemble = num_ensemble
self.dim_x = dim_x
self.dim_z = dim_z

def multivariate_normal_sampler(self, mean, cov, k):
sampler = MultivariateNormal(mean, cov)
return sampler.sample((k,))

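    # format_state lifts a single state estimate into an ensemble: the state is
    # tiled num_ensemble times and every copy is perturbed with zero-mean Gaussian
    # noise (covariance 0.1 * I), so the filter starts from a spread of hypotheses
    # rather than a point estimate.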
def format_state(self, state):
state = repeat(state, "k dim -> n k dim", n=self.num_ensemble)
state = rearrange(state, "n k dim -> (n k) dim")
cov = torch.eye(self.dim_x) * 0.1
init_dist = self.multivariate_normal_sampler(
torch.zeros(self.dim_x), cov, self.num_ensemble
)
        device = state.device  # keep the perturbation on the same device as the state
init_dist = init_dist.to(device)
state = state + init_dist
state = state.to(dtype=torch.float32)
return state



class Seq_MLP_process_model(nn.Module):
def __init__(self, num_ensemble, dim_x, win_size, dim_model, num_heads):
super(Seq_MLP_process_model, self).__init__()
self.num_ensemble = num_ensemble
self.dim_x = dim_x
self.dim_model = dim_model
self.num_heads = num_heads
self.win_size = win_size

self.bayes1 = LinearFlipout(in_features=self.dim_x * win_size, out_features=256)
self.bayes3 = LinearFlipout(in_features=256, out_features=512)
self.bayes_m2 = torch.nn.Linear(512, self.dim_x)

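    # input:  [batch_size, num_ensemble, win_size, dim_x] (window of past state ensembles)
    # output: [batch_size, num_ensemble, dim_x] (stochastic one-step prediction; the
    # LinearFlipout layers resample weight perturbations on every forward pass)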
def forward(self, input):
batch_size = input.shape[0]
input = rearrange(input, "n en k dim -> (n en) (k dim)")
# branch of the state
x, _ = self.bayes1(input)
x = F.leaky_relu(x)
x, _ = self.bayes3(x)
x = F.leaky_relu(x)
x = self.bayes_m2(x)
output = rearrange(x, "(bs en) dim -> bs en dim", en=self.num_ensemble)
return output


class NewObservationNoise(nn.Module):
def __init__(self, dim_z, r_diag):
"""
        The observation noise model learns the observation noise covariance matrix
        R from the learned observation. A Kalman filter requires an explicit
        matrix for R, so we construct the diagonal of R to model the noise here.
input -> [batch_size, 1, encoding/dim_z]
output -> [batch_size, dim_z, dim_z]
"""
super(NewObservationNoise, self).__init__()
self.dim_z = dim_z
self.r_diag = r_diag

self.fc1 = nn.Linear(self.dim_z, 32)
self.fc2 = nn.Linear(32, self.dim_z)

def forward(self, inputs):
        device = inputs.device  # keep constants on the same device as the input
batch_size = inputs.shape[0]
constant = np.ones(self.dim_z) * 1e-3
init = np.sqrt(np.square(self.r_diag) - constant)
diag = self.fc1(inputs)
diag = F.relu(diag)
diag = self.fc2(diag)
diag = torch.square(diag + torch.Tensor(constant).to(device)) + torch.Tensor(
init
).to(device)
diag = rearrange(diag, "bs k dim -> (bs k) dim")
R = torch.diag_embed(diag)
return R


class SeqSensorModel(nn.Module):
"""
    the sensor model takes the current raw sensor reading (usually a
    high-dimensional signal such as images) and maps it to a low-dimensional
    observation. Many advanced model architectures can be explored here,
    e.g., Vision Transformers, FlowNet, RAFT, the ResNet family, etc.
input -> [batch_size, 1, win, raw_input]
output -> [batch_size, num_ensemble, dim_z]
"""

def __init__(self, num_ensemble, dim_z, win_size, input_size_1):
super(SeqSensorModel, self).__init__()
self.dim_z = dim_z
self.num_ensemble = num_ensemble

self.fc2 = nn.Linear(input_size_1 * win_size, 256)
self.fc3 = LinearFlipout(256, 256)
self.fc5 = LinearFlipout(256, 64)
self.fc6 = LinearFlipout(64, self.dim_z)

def forward(self, x):
batch_size = x.shape[0]
x = rearrange(x, "bs k en dim -> bs (k en dim)")
x = repeat(x, "bs dim -> bs k dim", k=self.num_ensemble)
x = rearrange(x, "bs k dim -> (bs k) dim")

x = self.fc2(x)
x = F.leaky_relu(x)
x, _ = self.fc3(x)
x = F.leaky_relu(x)
x, _ = self.fc5(x)
x = F.leaky_relu(x)
encoding = x
obs, _ = self.fc6(x)
obs = rearrange(
obs, "(bs k) dim -> bs k dim", bs=batch_size, k=self.num_ensemble
)
obs_z = torch.mean(obs, axis=1)
obs_z = rearrange(obs_z, "bs (k dim) -> bs k dim", k=1)
encoding = rearrange(
encoding, "(bs k) dim -> bs k dim", bs=batch_size, k=self.num_ensemble
)
encoding = torch.mean(encoding, axis=1)
encoding = rearrange(encoding, "(bs k) dim -> bs k dim", bs=batch_size, k=1)
return obs, obs_z, encoding




class new_smartwatch_model(nn.Module):
def __init__(self, num_ensemble, win_size, dim_x, dim_z, input_size_1):
super(new_smartwatch_model, self).__init__()
self.num_ensemble = num_ensemble
self.dim_x = dim_x
self.dim_z = dim_z
self.win_size = win_size
        self.r_diag = np.ones(self.dim_z, dtype=np.float32) * 0.05

# instantiate model
self.process_model = Seq_MLP_process_model(
self.num_ensemble, self.dim_x, self.win_size, 256, 8
)
self.sensor_model = SeqSensorModel(
self.num_ensemble, self.dim_z, win_size, input_size_1
)
self.observation_noise = NewObservationNoise(self.dim_z, self.r_diag)

def forward(self, inputs, states):
# decompose inputs and states
        batch_size = inputs.shape[0]
raw_obs = inputs
state_old = states

##### prediction step #####
state_pred = self.process_model(state_old)
m_A = torch.mean(state_pred, axis=1) # m_A -> [bs, dim_x]

# zero mean
mean_A = repeat(m_A, "bs dim -> bs k dim", k=self.num_ensemble)
A = state_pred - mean_A
A = rearrange(A, "bs k dim -> bs dim k")

##### update step #####

        # since the observation model is the identity function, H(x) = x
H_X = state_pred
mean = torch.mean(H_X, axis=1)
H_X_mean = rearrange(mean, "bs (k dim) -> bs k dim", k=1)
m = repeat(mean, "bs dim -> bs k dim", k=self.num_ensemble)
H_A = H_X - m
# transpose operation
H_XT = rearrange(H_X, "bs k dim -> bs dim k")
H_AT = rearrange(H_A, "bs k dim -> bs dim k")

# get learned observation
ensemble_z, z, encoding = self.sensor_model(raw_obs)

# measurement update
y = rearrange(ensemble_z, "bs k dim -> bs dim k")
R = self.observation_noise(z)

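        # ensemble Kalman update: with A the centered state ensemble and H_A the
        # centered predicted observations (N = num_ensemble), computed batch-wise:
        #   S = (1/(N-1)) * H_A^T H_A + R        (innovation covariance)
        #   K = (1/(N-1)) * A H_A S^(-1)         (Kalman gain)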
innovation = (1 / (self.num_ensemble - 1)) * torch.matmul(H_AT, H_A) + R
inv_innovation = torch.linalg.inv(innovation)
K = (1 / (self.num_ensemble - 1)) * torch.matmul(
torch.matmul(A, H_A), inv_innovation
)

gain = rearrange(torch.matmul(K, y - H_XT), "bs dim k -> bs k dim")
state_new = state_pred + gain

# gather output
m_state_new = torch.mean(state_new, axis=1)
m_state_new = rearrange(m_state_new, "bs (k dim) -> bs k dim", k=1)
m_state_pred = rearrange(m_A, "bs (k dim) -> bs k dim", k=1)
output = (
state_new.to(dtype=torch.float32),
m_state_new.to(dtype=torch.float32),
m_state_pred.to(dtype=torch.float32),
z.to(dtype=torch.float32),
ensemble_z.to(dtype=torch.float32),
)
return output
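
For quick verification, a minimal smoke test of the filter above. All sizes are illustrative placeholders (num_ensemble=48 and window_size=10 mirror the example script; dim_x, dim_z, and the raw input size are made up here), and dim_x must equal dim_z because the observation model is the identity:

    import torch

    from wear_mocap_ape.estimate.kalman_models import new_smartwatch_model

    num_ensemble, win_size, dim_x, dim_z, input_size = 48, 10, 14, 14, 22
    model = new_smartwatch_model(num_ensemble, win_size, dim_x, dim_z, input_size)

    raw_obs = torch.randn(1, 1, win_size, input_size)       # [batch, 1, win, raw_input]
    states = torch.randn(1, num_ensemble, win_size, dim_x)  # window of past state ensembles

    state_new, m_state_new, m_state_pred, z, ensemble_z = model(raw_obs, states)
    print(state_new.shape)    # torch.Size([1, 48, 14])
    print(m_state_new.shape)  # torch.Size([1, 1, 14])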