Skip to content

Commit

Permalink
Motion tracking demo using bokeh for live visualization (#71)
Browse files Browse the repository at this point in the history
* motion tracking demo

* use full imports and allow for blocking mode for tests

* add recording

* Update process.py

* Update process.py

* minor changes from review

* minor changes

* cleaning up lfs messup

* add readme

* add readme

* minor changes

* added more infos

* Rename Readme.md to README.md

---------

Co-authored-by: SveaMeyer13 <svea.meyer@tum.de>
  • Loading branch information
PhilippPlank and SveaMeyer13 authored Sep 19, 2023
1 parent bd7f344 commit d5be705
Show file tree
Hide file tree
Showing 9 changed files with 705 additions and 0 deletions.
2 changes: 2 additions & 0 deletions demos/motion_tracking/.gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.aedat4 filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
42 changes: 42 additions & 0 deletions demos/motion_tracking/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Dynamic Neural Fields - Demos

This readme assumes that you have installed lava, lava-loihi and lava-dnf in the same virtual environment. Additionally, you should check if you downloaded the input recording which is stored using git lfs. Check if the file size is reasonable.

```bash
lava-dnf/demos/motion_tracking$ ls -l
-rw-r--r-- 1 <user> <group> 93353870 Sep 18 08:01 dvs_recording.aedat4
```
If the file size is only 133, then use the following command before running:
```bash
lava-dnf/demos/motion_tracking$ git lfs pull
```

### Running the demos
The demo will run in your browser via port-forwarding. Choose a random port_num between 10000 and 20000.
(This is to avoid that multiple users try to use the same port)

#### Connect to external vlab with port-forwarding
```bash
ssh <my-vm>.research.intel-research.net -L 127.0.0.1:<port_num>:127.0.0.1:<port_num>
```

#### Activate your virtual environment
Location of your virtual enviornment might differ.
```bash
source lava/lava_nx_env/bin/activate
```

#### Navigate to the motion_tracking demo:
```bash
cd lava-dnf/demos/motion_tracking
```
#### Start the bokeh app
```bash
bokeh serve main_motion_tracking.py --port <port_num>
```

open your browser and type:
http://localhost:<port_num>/main_motion_tracking

As the network is pre-compiled, the demo will appear immediately, and you just need to click the "run" button to start the demo.
It is currently not possible to interrupt the demo while it is running. Please wait for the demo to terminate and click the "close" button for all processes to terminate gracefully.
149 changes: 149 additions & 0 deletions demos/motion_tracking/dvs_file_input/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy import signal
try:
from dv import AedatFile
except ModuleNotFoundError:
print("Module 'dv' is not installed. Please install module 'dv' in order"
" to use the process DVSFileInput.")
exit()

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.ports.ports import OutPort
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.core.resources import CPU
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.decorator import implements, requires



class DVSFileInput(AbstractProcess):
"""Process to read from .aedat4 file, downsample the frame in different
modes (max pooling, convolution, normal downsampling) and send out the
down sampled event frame."""
def __init__(self,
true_height: int,
true_width: int,
file_path: str,
flatten: bool = False,
down_sample_factor: int = 1,
down_sample_mode: str = "down_sample",
num_steps=1) -> None:
super().__init__(true_height=true_height,
true_width=true_width,
file_path=file_path,
flatten=flatten,
down_sample_factor=down_sample_factor,
down_sample_mode=down_sample_mode,
num_steps=num_steps)

down_sampled_height = true_height // down_sample_factor
down_sampled_width = true_width // down_sample_factor

if flatten:
out_shape = (down_sampled_width * down_sampled_height,)
else:
out_shape = (down_sampled_width, down_sampled_height)
self.event_frame_out = OutPort(shape=out_shape)


@implements(proc=DVSFileInput, protocol=LoihiProtocol)
@requires(CPU)
class PyDVSFileInputPM(PyLoihiProcessModel):
event_frame_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)

def __init__(self, proc_params):
super().__init__(proc_params)
self._true_height = proc_params["true_height"]
self._true_width = proc_params["true_width"]
self._true_shape = (self._true_width, self._true_height)
self._file_path = proc_params["file_path"]
self._aedat_file = AedatFile(self._file_path)
self._event_stream = self._aedat_file["events"].numpy()
self._frame_stream = self._aedat_file["frames"]
self._flatten = proc_params["flatten"]
self._down_sample_factor = proc_params["down_sample_factor"]
self._down_sample_mode = proc_params["down_sample_mode"]
self._down_sampled_height = \
self._true_height // self._down_sample_factor
self._down_sampled_width = \
self._true_width // self._down_sample_factor
self._down_sampled_shape = (self._down_sampled_width,
self._down_sampled_height)
self._num_steps = proc_params["num_steps"]

def run_spk(self):
# Get next events from .aedat4 file.
events = self._event_stream.__next__()
xs, ys, ps = events['x'], events['y'], events['polarity']

# Write events to event frame.
event_frame = np.zeros(self._true_shape)
event_frame[xs[ps == 0], ys[ps == 0]] = 1
event_frame[xs[ps == 1], ys[ps == 1]] = 1

# Downsample event frame.
if self._down_sample_mode == "down_sampling":
event_frame_small = \
event_frame[::self._down_sample_factor,
::self._down_sample_factor]

event_frame_small = \
event_frame_small[:self._down_sampled_height,
:self._down_sampled_width]
elif self._down_sample_mode == "max_pooling":
event_frame_small = \
self._pool_2d(event_frame, kernel_size=self._down_sample_factor,
stride=self._down_sample_factor, padding=0,
pool_mode='max')
elif self._down_sample_mode == "convolution":
event_frame_small = \
self._convolution(event_frame)
else:
raise ValueError(f"Unknown down_sample_mode "
f"{self._down_sample_mode}")

if self._flatten:
event_frame_small = event_frame_small.flatten()
self.event_frame_out.send(event_frame_small)

def _pool_2d(self, matrix: np.ndarray, kernel_size: int, stride: int,
padding: int = 0, pool_mode: str = 'max'):
# Padding
padded_matrix = np.pad(matrix, padding, mode='constant')

# Window view of A
output_shape = ((padded_matrix.shape[0] - kernel_size) // stride + 1,
(padded_matrix.shape[1] - kernel_size) // stride + 1)
shape_w = (output_shape[0], output_shape[1], kernel_size, kernel_size)
strides_w = (stride * padded_matrix.strides[0],
stride * padded_matrix.strides[1],
padded_matrix.strides[0],
padded_matrix.strides[1])
matrix_w = as_strided(padded_matrix, shape_w, strides_w)

# Return the result of pooling
if pool_mode == 'max':
return matrix_w.max(axis=(2, 3))
elif pool_mode == 'avg':
return matrix_w.mean(axis=(2, 3))

def _convolution(self, matrix: np.ndarray, kernel_size: int = 3):
kernel = np.ones((kernel_size, kernel_size))
event_frame_convolved = signal.convolve2d(matrix, kernel, mode="same")

event_frame_small = \
event_frame_convolved[::self._down_sample_factor,
::self._down_sample_factor]

event_frame_small = \
event_frame_small[:self._down_sampled_width,
:self._down_sampled_height]

return event_frame_small
3 changes: 3 additions & 0 deletions demos/motion_tracking/dvs_recording.aedat4
Git LFS file not shown
156 changes: 156 additions & 0 deletions demos/motion_tracking/main_motion_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import sys
import threading
import functools
import multiprocessing
try:
from bokeh.plotting import figure, curdoc
from bokeh.layouts import gridplot, Spacer
from bokeh.models import LinearColorMapper, ColorBar, Title, Button
from bokeh.models.ranges import DataRange1d
except ModuleNotFoundError:
print("Module 'bokeh' is not installed. Please install module 'bokeh' in"
" order to run the motion tracking demo.")
exit()
from motion_tracker import MotionTracker
from lava.utils.serialization import load
from lava.utils import loihi

loihi.use_slurm_host()

# ==========================================================================
# Parameters
# ==========================================================================
recv_pipe, send_pipe = multiprocessing.Pipe()
num_steps = 3000

class BokehControl:
"""Class which holds control state for Bokeh."""
stop_button_pressed: bool = False
# ==========================================================================
# Set up network
# ==========================================================================

_, executable = load("mt_executable.pickle")
network = MotionTracker(send_pipe,
num_steps,
executable=executable)
# ==========================================================================
# Bokeh Helpers
# ==========================================================================


def callback_run() -> None:
network.start()


def callback_stop() -> None:
BokehControl.stop_button_pressed = True
network.stop()
sys.exit()


def create_plot(plot_base_width,
data_shape,
title,
max_value=1) -> (figure, figure.image):
x_range = DataRange1d(start=0,
end=data_shape[0],
bounds=(0, data_shape[0]),
range_padding=50,
range_padding_units='percent')
y_range = DataRange1d(start=0,
end=data_shape[1],
bounds=(0, data_shape[1]),
range_padding=50,
range_padding_units='percent')

pw = plot_base_width
ph = int(pw * data_shape[1] / data_shape[0])
plot = figure(width=pw,
height=ph,
x_range=x_range,
y_range=y_range,
match_aspect=True,
tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
toolbar_location=None)

image = plot.image([], x=0, y=0, dw=data_shape[0], dh=data_shape[1],
palette="Viridis256", level="image")

plot.add_layout(Title(text=title, align="center"), "above")

x_grid = list(range(data_shape[0]))
plot.xgrid[0].ticker = x_grid
y_grid = list(range(data_shape[1]))
plot.ygrid[0].ticker = y_grid
plot.xgrid.grid_line_color = None
plot.ygrid.grid_line_color = None

color = LinearColorMapper(palette="Viridis256", low=0, high=max_value)
image.glyph.color_mapper = color

cb = ColorBar(color_mapper=color)
plot.add_layout(cb, 'right')

return plot, image


# ==========================================================================
# Instantiating Bokeh document
# ==========================================================================
bokeh_document = curdoc()

# create plots
dvs_frame_p, dvs_frame_im = create_plot(
400, network.downsampled_shape, "DVS file input (max pooling)",
max_value=10)
dnf_multipeak_rates_p, dnf_multipeak_rates_im = create_plot(
400, network.downsampled_shape, "DNF multi-peak (spike rates)")
dnf_selective_rates_p, dnf_selective_rates_im = create_plot(
400, network.downsampled_shape, "DNF selective (spike rates)")

# add a button widget and configure with the call back
button_run = Button(label="Run")
button_run.on_click(callback_run)

button_stop = Button(label="Close")
button_stop.on_click(callback_stop)
# finalize layout (with spacer as placeholder)
spacer = Spacer(height=40)
bokeh_document.add_root(
gridplot([[button_run, None, button_stop],
[None, spacer, None],
[dvs_frame_p, dnf_multipeak_rates_p, dnf_selective_rates_p]],
toolbar_options=dict(logo=None)))


# ==========================================================================
# Bokeh Update
# ==========================================================================
def update(dvs_frame_ds_image,
dnf_multipeak_rates_ds_image,
dnf_selective_rates_ds_image) -> None:
dvs_frame_im.data_source.data["image"] = [dvs_frame_ds_image]
dnf_multipeak_rates_im.data_source.data["image"] = \
[dnf_multipeak_rates_ds_image]
dnf_selective_rates_im.data_source.data["image"] = \
[dnf_selective_rates_ds_image]


# ==========================================================================
# Bokeh Main loop
# ==========================================================================
def main_loop() -> None:
while not BokehControl.stop_button_pressed:
if recv_pipe.poll():
data_for_plot_dict = recv_pipe.recv()
bokeh_document.add_next_tick_callback(
functools.partial(update, **data_for_plot_dict))


thread = threading.Thread(target=main_loop)
thread.start()
Loading

0 comments on commit d5be705

Please sign in to comment.