Motion tracking demo using bokeh for live visualization (#71)
* motion tracking demo
* use full imports and allow for blocking mode for tests
* add recording
* Update process.py
* Update process.py
* minor changes from review
* minor changes
* cleaning up lfs messup
* add readme
* add readme
* minor changes
* added more infos
* Rename Readme.md to README.md

Co-authored-by: SveaMeyer13 <svea.meyer@tum.de>
Parent: bd7f344 · Commit: d5be705 · 9 changed files with 705 additions and 0 deletions.
@@ -0,0 +1,2 @@
```
*.aedat4 filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
```
@@ -0,0 +1,42 @@

# Dynamic Neural Fields - Demos

This README assumes that you have installed lava, lava-loihi, and lava-dnf in the same virtual environment. Additionally, check that you have downloaded the input recording, which is stored using Git LFS, by verifying that its file size is reasonable:

```bash
lava-dnf/demos/motion_tracking$ ls -l
-rw-r--r-- 1 <user> <group> 93353870 Sep 18 08:01 dvs_recording.aedat4
```

If the file size is only 133 bytes (i.e., only the Git LFS pointer was downloaded), run the following command first:

```bash
lava-dnf/demos/motion_tracking$ git lfs pull
```

### Running the demos
The demo runs in your browser via port forwarding. Choose a random port_num between 10000 and 20000 (this avoids multiple users trying to use the same port).

#### Connect to the external vlab with port forwarding
```bash
ssh <my-vm>.research.intel-research.net -L 127.0.0.1:<port_num>:127.0.0.1:<port_num>
```

#### Activate your virtual environment
The location of your virtual environment might differ.
```bash
source lava/lava_nx_env/bin/activate
```

#### Navigate to the motion_tracking demo
```bash
cd lava-dnf/demos/motion_tracking
```

#### Start the Bokeh app
```bash
bokeh serve main_motion_tracking.py --port <port_num>
```

Open your browser and go to:
http://localhost:<port_num>/main_motion_tracking

As the network is pre-compiled, the demo will appear immediately; just click the "Run" button to start it. It is currently not possible to interrupt the demo while it is running. Wait for the demo to terminate, then click the "Close" button so that all processes terminate gracefully.
@@ -0,0 +1,149 @@

```python
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy import signal
try:
    from dv import AedatFile
except ModuleNotFoundError:
    print("Module 'dv' is not installed. Please install module 'dv' in order"
          " to use the process DVSFileInput.")
    exit()

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.ports.ports import OutPort
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.core.resources import CPU
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.decorator import implements, requires


class DVSFileInput(AbstractProcess):
    """Process to read from .aedat4 file, downsample the frame in different
    modes (max pooling, convolution, normal downsampling) and send out the
    down sampled event frame."""
    def __init__(self,
                 true_height: int,
                 true_width: int,
                 file_path: str,
                 flatten: bool = False,
                 down_sample_factor: int = 1,
                 down_sample_mode: str = "down_sample",
                 num_steps=1) -> None:
        super().__init__(true_height=true_height,
                         true_width=true_width,
                         file_path=file_path,
                         flatten=flatten,
                         down_sample_factor=down_sample_factor,
                         down_sample_mode=down_sample_mode,
                         num_steps=num_steps)

        down_sampled_height = true_height // down_sample_factor
        down_sampled_width = true_width // down_sample_factor

        if flatten:
            out_shape = (down_sampled_width * down_sampled_height,)
        else:
            out_shape = (down_sampled_width, down_sampled_height)
        self.event_frame_out = OutPort(shape=out_shape)


@implements(proc=DVSFileInput, protocol=LoihiProtocol)
@requires(CPU)
class PyDVSFileInputPM(PyLoihiProcessModel):
    event_frame_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)

    def __init__(self, proc_params):
        super().__init__(proc_params)
        self._true_height = proc_params["true_height"]
        self._true_width = proc_params["true_width"]
        self._true_shape = (self._true_width, self._true_height)
        self._file_path = proc_params["file_path"]
        self._aedat_file = AedatFile(self._file_path)
        self._event_stream = self._aedat_file["events"].numpy()
        self._frame_stream = self._aedat_file["frames"]
        self._flatten = proc_params["flatten"]
        self._down_sample_factor = proc_params["down_sample_factor"]
        self._down_sample_mode = proc_params["down_sample_mode"]
        self._down_sampled_height = \
            self._true_height // self._down_sample_factor
        self._down_sampled_width = \
            self._true_width // self._down_sample_factor
        self._down_sampled_shape = (self._down_sampled_width,
                                    self._down_sampled_height)
        self._num_steps = proc_params["num_steps"]

    def run_spk(self):
        # Get next events from .aedat4 file.
        events = self._event_stream.__next__()
        xs, ys, ps = events['x'], events['y'], events['polarity']

        # Write events to event frame.
        event_frame = np.zeros(self._true_shape)
        event_frame[xs[ps == 0], ys[ps == 0]] = 1
        event_frame[xs[ps == 1], ys[ps == 1]] = 1

        # Downsample event frame.
        if self._down_sample_mode == "down_sampling":
            event_frame_small = \
                event_frame[::self._down_sample_factor,
                            ::self._down_sample_factor]

            # Crop to the OutPort shape (width, height).
            event_frame_small = \
                event_frame_small[:self._down_sampled_width,
                                  :self._down_sampled_height]
        elif self._down_sample_mode == "max_pooling":
            event_frame_small = \
                self._pool_2d(event_frame,
                              kernel_size=self._down_sample_factor,
                              stride=self._down_sample_factor, padding=0,
                              pool_mode='max')
        elif self._down_sample_mode == "convolution":
            event_frame_small = \
                self._convolution(event_frame)
        else:
            raise ValueError(f"Unknown down_sample_mode "
                             f"{self._down_sample_mode}")

        if self._flatten:
            event_frame_small = event_frame_small.flatten()
        self.event_frame_out.send(event_frame_small)

    def _pool_2d(self, matrix: np.ndarray, kernel_size: int, stride: int,
                 padding: int = 0, pool_mode: str = 'max'):
        # Padding
        padded_matrix = np.pad(matrix, padding, mode='constant')

        # Window view of the padded matrix
        output_shape = ((padded_matrix.shape[0] - kernel_size) // stride + 1,
                        (padded_matrix.shape[1] - kernel_size) // stride + 1)
        shape_w = (output_shape[0], output_shape[1], kernel_size, kernel_size)
        strides_w = (stride * padded_matrix.strides[0],
                     stride * padded_matrix.strides[1],
                     padded_matrix.strides[0],
                     padded_matrix.strides[1])
        matrix_w = as_strided(padded_matrix, shape_w, strides_w)

        # Return the result of pooling
        if pool_mode == 'max':
            return matrix_w.max(axis=(2, 3))
        elif pool_mode == 'avg':
            return matrix_w.mean(axis=(2, 3))

    def _convolution(self, matrix: np.ndarray, kernel_size: int = 3):
        kernel = np.ones((kernel_size, kernel_size))
        event_frame_convolved = signal.convolve2d(matrix, kernel, mode="same")

        event_frame_small = \
            event_frame_convolved[::self._down_sample_factor,
                                  ::self._down_sample_factor]

        event_frame_small = \
            event_frame_small[:self._down_sampled_width,
                              :self._down_sampled_height]

        return event_frame_small
```
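For orientation, here is a minimal sketch of how the DVSFileInput process defined above might be instantiated on its own. The module name (the commit message suggests this file is process.py), the 240x180 sensor resolution, and the downsampling factor are assumptions for illustration, not values taken from this commit; only the recording file name comes from the README above.

```python
# Sketch only: module name, sensor resolution, and downsampling factor are
# assumptions for illustration; they are not taken from this commit.
from process import DVSFileInput  # assumed module name (see commit message)

dvs_input = DVSFileInput(true_height=180,        # assumed DVS resolution
                         true_width=240,
                         file_path="dvs_recording.aedat4",
                         flatten=False,
                         down_sample_factor=8,   # assumed factor
                         down_sample_mode="max_pooling",
                         num_steps=3000)

# The OutPort shape follows the integer-divided frame size:
# (true_width // factor, true_height // factor) = (30, 22).
print(dvs_input.event_frame_out.shape)
```

With flatten=True, the constructor instead builds a one-dimensional OutPort of size down_sampled_width * down_sampled_height, and run_spk() flattens each frame before sending it.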
Git LFS file not shown
@@ -0,0 +1,156 @@

```python
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import sys
import threading
import functools
import multiprocessing
try:
    from bokeh.plotting import figure, curdoc
    from bokeh.layouts import gridplot, Spacer
    from bokeh.models import LinearColorMapper, ColorBar, Title, Button
    from bokeh.models.ranges import DataRange1d
except ModuleNotFoundError:
    print("Module 'bokeh' is not installed. Please install module 'bokeh' in"
          " order to run the motion tracking demo.")
    exit()
from motion_tracker import MotionTracker
from lava.utils.serialization import load
from lava.utils import loihi

loihi.use_slurm_host()

# ==========================================================================
# Parameters
# ==========================================================================
recv_pipe, send_pipe = multiprocessing.Pipe()
num_steps = 3000


class BokehControl:
    """Class which holds control state for Bokeh."""
    stop_button_pressed: bool = False


# ==========================================================================
# Set up network
# ==========================================================================
_, executable = load("mt_executable.pickle")
network = MotionTracker(send_pipe,
                        num_steps,
                        executable=executable)

# ==========================================================================
# Bokeh Helpers
# ==========================================================================


def callback_run() -> None:
    network.start()


def callback_stop() -> None:
    BokehControl.stop_button_pressed = True
    network.stop()
    sys.exit()


def create_plot(plot_base_width,
                data_shape,
                title,
                max_value=1) -> (figure, figure.image):
    x_range = DataRange1d(start=0,
                          end=data_shape[0],
                          bounds=(0, data_shape[0]),
                          range_padding=50,
                          range_padding_units='percent')
    y_range = DataRange1d(start=0,
                          end=data_shape[1],
                          bounds=(0, data_shape[1]),
                          range_padding=50,
                          range_padding_units='percent')

    pw = plot_base_width
    ph = int(pw * data_shape[1] / data_shape[0])
    plot = figure(width=pw,
                  height=ph,
                  x_range=x_range,
                  y_range=y_range,
                  match_aspect=True,
                  tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
                  toolbar_location=None)

    image = plot.image([], x=0, y=0, dw=data_shape[0], dh=data_shape[1],
                       palette="Viridis256", level="image")

    plot.add_layout(Title(text=title, align="center"), "above")

    x_grid = list(range(data_shape[0]))
    plot.xgrid[0].ticker = x_grid
    y_grid = list(range(data_shape[1]))
    plot.ygrid[0].ticker = y_grid
    plot.xgrid.grid_line_color = None
    plot.ygrid.grid_line_color = None

    color = LinearColorMapper(palette="Viridis256", low=0, high=max_value)
    image.glyph.color_mapper = color

    cb = ColorBar(color_mapper=color)
    plot.add_layout(cb, 'right')

    return plot, image


# ==========================================================================
# Instantiating Bokeh document
# ==========================================================================
bokeh_document = curdoc()

# create plots
dvs_frame_p, dvs_frame_im = create_plot(
    400, network.downsampled_shape, "DVS file input (max pooling)",
    max_value=10)
dnf_multipeak_rates_p, dnf_multipeak_rates_im = create_plot(
    400, network.downsampled_shape, "DNF multi-peak (spike rates)")
dnf_selective_rates_p, dnf_selective_rates_im = create_plot(
    400, network.downsampled_shape, "DNF selective (spike rates)")

# add a button widget and configure with the callback
button_run = Button(label="Run")
button_run.on_click(callback_run)

button_stop = Button(label="Close")
button_stop.on_click(callback_stop)

# finalize layout (with spacer as placeholder)
spacer = Spacer(height=40)
bokeh_document.add_root(
    gridplot([[button_run, None, button_stop],
              [None, spacer, None],
              [dvs_frame_p, dnf_multipeak_rates_p, dnf_selective_rates_p]],
             toolbar_options=dict(logo=None)))


# ==========================================================================
# Bokeh Update
# ==========================================================================
def update(dvs_frame_ds_image,
           dnf_multipeak_rates_ds_image,
           dnf_selective_rates_ds_image) -> None:
    dvs_frame_im.data_source.data["image"] = [dvs_frame_ds_image]
    dnf_multipeak_rates_im.data_source.data["image"] = \
        [dnf_multipeak_rates_ds_image]
    dnf_selective_rates_im.data_source.data["image"] = \
        [dnf_selective_rates_ds_image]


# ==========================================================================
# Bokeh Main loop
# ==========================================================================
def main_loop() -> None:
    while not BokehControl.stop_button_pressed:
        if recv_pipe.poll():
            data_for_plot_dict = recv_pipe.recv()
            bokeh_document.add_next_tick_callback(
                functools.partial(update, **data_for_plot_dict))


thread = threading.Thread(target=main_loop)
thread.start()
```
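The producer side of recv_pipe lives inside MotionTracker, which is not part of this commit excerpt. For context, the sketch below only illustrates the payload format the plotting code above expects: a dict whose keys match the keyword arguments of update(), with one 2D array per plot. The concrete shape and the idea of calling send_pipe.send() directly are assumptions for illustration.

```python
# Sketch of the payload format consumed by main_loop()/update(); the real
# sender is inside MotionTracker and is not shown in this commit excerpt.
import numpy as np

downsampled_shape = (30, 22)  # assumed; the demo reads it from the network

payload = {
    "dvs_frame_ds_image": np.random.randint(0, 10, size=downsampled_shape),
    "dnf_multipeak_rates_ds_image": np.zeros(downsampled_shape),
    "dnf_selective_rates_ds_image": np.zeros(downsampled_shape),
}
send_pipe.send(payload)

# main_loop() polls recv_pipe, receives this dict, and schedules
# update(**payload) on the Bokeh document via add_next_tick_callback().
```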