From d5be7054c90b35ff3b8eac1cff8548cbdc45783e Mon Sep 17 00:00:00 2001
From: PhilippPlank <32519998+PhilippPlank@users.noreply.github.com>
Date: Tue, 19 Sep 2023 12:25:46 +0200
Subject: [PATCH] Motion tracking demo using bokeh for live visualization
 (#71)

* motion tracking demo
* use full imports and allow for blocking mode for tests
* add recording
* Update process.py
* Update process.py
* minor changes from review
* minor changes
* cleaning up lfs messup
* add readme
* add readme
* minor changes
* added more infos
* Rename Readme.md to README.md

---------

Co-authored-by: SveaMeyer13
---
 demos/motion_tracking/.gitattributes          |   2 +
 demos/motion_tracking/README.md               |  42 ++++
 .../motion_tracking/dvs_file_input/process.py | 149 +++++++++++
 demos/motion_tracking/dvs_recording.aedat4    |   3 +
 demos/motion_tracking/main_motion_tracking.py | 156 ++++++++++++
 demos/motion_tracking/motion_tracker.py       | 233 ++++++++++++++++++
 demos/motion_tracking/mt_executable.pickle    |   3 +
 demos/motion_tracking/process_out/process.py  |  60 +++++
 demos/motion_tracking/rate_reader/process.py  |  57 +++++
 9 files changed, 705 insertions(+)
 create mode 100644 demos/motion_tracking/.gitattributes
 create mode 100644 demos/motion_tracking/README.md
 create mode 100644 demos/motion_tracking/dvs_file_input/process.py
 create mode 100644 demos/motion_tracking/dvs_recording.aedat4
 create mode 100644 demos/motion_tracking/main_motion_tracking.py
 create mode 100644 demos/motion_tracking/motion_tracker.py
 create mode 100644 demos/motion_tracking/mt_executable.pickle
 create mode 100644 demos/motion_tracking/process_out/process.py
 create mode 100644 demos/motion_tracking/rate_reader/process.py

diff --git a/demos/motion_tracking/.gitattributes b/demos/motion_tracking/.gitattributes
new file mode 100644
index 0000000..449aa90
--- /dev/null
+++ b/demos/motion_tracking/.gitattributes
@@ -0,0 +1,2 @@
+*.aedat4 filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
diff --git a/demos/motion_tracking/README.md b/demos/motion_tracking/README.md
new file mode 100644
index 0000000..db59906
--- /dev/null
+++ b/demos/motion_tracking/README.md
@@ -0,0 +1,42 @@
+# Dynamic Neural Fields - Demos
+
+This README assumes that you have installed lava, lava-loihi, and lava-dnf in
+the same virtual environment. Additionally, check that you have downloaded
+the input recording, which is stored using Git LFS, by verifying that its
+file size is reasonable:
+
+```bash
+lava-dnf/demos/motion_tracking$ ls -l
+-rw-r--r-- 1 93353870 Sep 18 08:01 dvs_recording.aedat4
+```
+If the file size is only 133 bytes, you only have the LFS pointer file; fetch
+the actual recording before running the demo:
+```bash
+lava-dnf/demos/motion_tracking$ git lfs pull
+```
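+
+For reference, an un-fetched LFS pointer is a small text file; it should look
+roughly like this (the oid and size below are taken from the pointer
+committed in this patch):
+```bash
+lava-dnf/demos/motion_tracking$ cat dvs_recording.aedat4
+version https://git-lfs.github.com/spec/v1
+oid sha256:73d20a59a2328c45a0356c1a9952e3f32153ed8ddaadc56cad346d4a7cb5bf7a
+size 93353870
+```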
+
+### Running the demo
+The demo runs in your browser via port-forwarding. Choose a random
+<port_num> between 10000 and 20000; this avoids multiple users trying to use
+the same port.
+
+#### Connect to the external vlab with port-forwarding
+```bash
+ssh <vlab_host>.research.intel-research.net -L 127.0.0.1:<port_num>:127.0.0.1:<port_num>
+```
+
+#### Activate your virtual environment
+The location of your virtual environment might differ.
+```bash
+source lava/lava_nx_env/bin/activate
+```
+
+#### Navigate to the motion_tracking demo
+```bash
+cd lava-dnf/demos/motion_tracking
+```
+#### Start the bokeh app
+```bash
+bokeh serve main_motion_tracking.py --port <port_num>
+```
+
+Open your browser and go to:
+http://localhost:<port_num>/main_motion_tracking
+
+As the network is pre-compiled, the demo will appear immediately; just click
+the "Run" button to start it.
+
+It is currently not possible to interrupt the demo while it is running. Wait
+for the demo to terminate, then click the "Close" button so that all
+processes terminate gracefully.
diff --git a/demos/motion_tracking/dvs_file_input/process.py b/demos/motion_tracking/dvs_file_input/process.py
new file mode 100644
index 0000000..0872fdd
--- /dev/null
+++ b/demos/motion_tracking/dvs_file_input/process.py
@@ -0,0 +1,149 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: BSD-3-Clause
+# See: https://spdx.org/licenses/
+
+import numpy as np
+from numpy.lib.stride_tricks import as_strided
+from scipy import signal
+
+try:
+    from dv import AedatFile
+except ModuleNotFoundError:
+    print("Module 'dv' is not installed. Please install module 'dv' in order"
+          " to use the process DVSFileInput.")
+    exit()
+
+from lava.magma.core.process.process import AbstractProcess
+from lava.magma.core.process.ports.ports import OutPort
+from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
+from lava.magma.core.model.py.type import LavaPyType
+from lava.magma.core.model.py.ports import PyOutPort
+from lava.magma.core.resources import CPU
+from lava.magma.core.model.py.model import PyLoihiProcessModel
+from lava.magma.core.decorator import implements, requires
+
+
+class DVSFileInput(AbstractProcess):
+    """Process to read events from an .aedat4 file, downsample the event
+    frame in one of several modes (max pooling, convolution, or plain
+    downsampling), and send out the downsampled event frame."""
+    def __init__(self,
+                 true_height: int,
+                 true_width: int,
+                 file_path: str,
+                 flatten: bool = False,
+                 down_sample_factor: int = 1,
+                 down_sample_mode: str = "down_sampling",
+                 num_steps: int = 1) -> None:
+        super().__init__(true_height=true_height,
+                         true_width=true_width,
+                         file_path=file_path,
+                         flatten=flatten,
+                         down_sample_factor=down_sample_factor,
+                         down_sample_mode=down_sample_mode,
+                         num_steps=num_steps)
+
+        down_sampled_height = true_height // down_sample_factor
+        down_sampled_width = true_width // down_sample_factor
+
+        if flatten:
+            out_shape = (down_sampled_width * down_sampled_height,)
+        else:
+            out_shape = (down_sampled_width, down_sampled_height)
+        self.event_frame_out = OutPort(shape=out_shape)
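+        # Shape example: with true_width=240, true_height=180, and
+        # down_sample_factor=8 (the demo's defaults), the OutPort shape is
+        # (240 // 8, 180 // 8) = (30, 22), or (660,) if flatten=True.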
+
+
+@implements(proc=DVSFileInput, protocol=LoihiProtocol)
+@requires(CPU)
+class PyDVSFileInputPM(PyLoihiProcessModel):
+    event_frame_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)
+
+    def __init__(self, proc_params):
+        super().__init__(proc_params)
+        self._true_height = proc_params["true_height"]
+        self._true_width = proc_params["true_width"]
+        self._true_shape = (self._true_width, self._true_height)
+        self._file_path = proc_params["file_path"]
+        self._aedat_file = AedatFile(self._file_path)
+        self._event_stream = self._aedat_file["events"].numpy()
+        self._frame_stream = self._aedat_file["frames"]
+        self._flatten = proc_params["flatten"]
+        self._down_sample_factor = proc_params["down_sample_factor"]
+        self._down_sample_mode = proc_params["down_sample_mode"]
+        self._down_sampled_height = \
+            self._true_height // self._down_sample_factor
+        self._down_sampled_width = \
+            self._true_width // self._down_sample_factor
+        self._down_sampled_shape = (self._down_sampled_width,
+                                    self._down_sampled_height)
+        self._num_steps = proc_params["num_steps"]
+
+    def run_spk(self):
+        # Get the next batch of events from the .aedat4 file.
+        events = next(self._event_stream)
+        xs, ys, ps = events['x'], events['y'], events['polarity']
+
+        # Write events of both polarities into a single binary event frame.
+        event_frame = np.zeros(self._true_shape)
+        event_frame[xs[ps == 0], ys[ps == 0]] = 1
+        event_frame[xs[ps == 1], ys[ps == 1]] = 1
+
+        # Downsample the event frame.
+        if self._down_sample_mode == "down_sampling":
+            event_frame_small = \
+                event_frame[::self._down_sample_factor,
+                            ::self._down_sample_factor]
+
+            event_frame_small = \
+                event_frame_small[:self._down_sampled_width,
+                                  :self._down_sampled_height]
+        elif self._down_sample_mode == "max_pooling":
+            event_frame_small = \
+                self._pool_2d(event_frame,
+                              kernel_size=self._down_sample_factor,
+                              stride=self._down_sample_factor,
+                              padding=0,
+                              pool_mode='max')
+        elif self._down_sample_mode == "convolution":
+            event_frame_small = self._convolution(event_frame)
+        else:
+            raise ValueError(f"Unknown down_sample_mode "
+                             f"{self._down_sample_mode}")
+
+        if self._flatten:
+            event_frame_small = event_frame_small.flatten()
+        self.event_frame_out.send(event_frame_small)
+
+    def _pool_2d(self, matrix: np.ndarray, kernel_size: int, stride: int,
+                 padding: int = 0, pool_mode: str = 'max'):
+        # Pad the input.
+        padded_matrix = np.pad(matrix, padding, mode='constant')
+
+        # Create a sliding-window view of the padded matrix.
+        output_shape = ((padded_matrix.shape[0] - kernel_size) // stride + 1,
+                        (padded_matrix.shape[1] - kernel_size) // stride + 1)
+        shape_w = (output_shape[0], output_shape[1], kernel_size, kernel_size)
+        strides_w = (stride * padded_matrix.strides[0],
+                     stride * padded_matrix.strides[1],
+                     padded_matrix.strides[0],
+                     padded_matrix.strides[1])
+        matrix_w = as_strided(padded_matrix, shape_w, strides_w)
+
+        # Return the result of pooling.
+        if pool_mode == 'max':
+            return matrix_w.max(axis=(2, 3))
+        elif pool_mode == 'avg':
+            return matrix_w.mean(axis=(2, 3))
+        else:
+            raise ValueError(f"Unknown pool_mode {pool_mode}")
+
+    def _convolution(self, matrix: np.ndarray, kernel_size: int = 3):
+        kernel = np.ones((kernel_size, kernel_size))
+        event_frame_convolved = signal.convolve2d(matrix, kernel, mode="same")
+
+        event_frame_small = \
+            event_frame_convolved[::self._down_sample_factor,
+                                  ::self._down_sample_factor]
+
+        event_frame_small = \
+            event_frame_small[:self._down_sampled_width,
+                              :self._down_sampled_height]
+
+        return event_frame_small
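+
+# Downsampling modes at a glance: "down_sampling" keeps every
+# down_sample_factor-th pixel, "max_pooling" takes the maximum of each
+# factor-by-factor patch (which preserves sparse events best), and
+# "convolution" smooths with a 3x3 box kernel before subsampling.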
diff --git a/demos/motion_tracking/dvs_recording.aedat4 b/demos/motion_tracking/dvs_recording.aedat4
new file mode 100644
index 0000000..13a6d44
--- /dev/null
+++ b/demos/motion_tracking/dvs_recording.aedat4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73d20a59a2328c45a0356c1a9952e3f32153ed8ddaadc56cad346d4a7cb5bf7a
+size 93353870
diff --git a/demos/motion_tracking/main_motion_tracking.py b/demos/motion_tracking/main_motion_tracking.py
new file mode 100644
index 0000000..c4c963d
--- /dev/null
+++ b/demos/motion_tracking/main_motion_tracking.py
@@ -0,0 +1,156 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: BSD-3-Clause
+# See: https://spdx.org/licenses/
+
+import sys
+import threading
+import functools
+import multiprocessing
+
+try:
+    from bokeh.plotting import figure, curdoc
+    from bokeh.layouts import gridplot, Spacer
+    from bokeh.models import LinearColorMapper, ColorBar, Title, Button
+    from bokeh.models.ranges import DataRange1d
+except ModuleNotFoundError:
+    print("Module 'bokeh' is not installed. Please install module 'bokeh' in"
+          " order to run the motion tracking demo.")
+    exit()
+
+from motion_tracker import MotionTracker
+from lava.utils.serialization import load
+from lava.utils import loihi
+
+loihi.use_slurm_host()
+
+# ==========================================================================
+# Parameters
+# ==========================================================================
+recv_pipe, send_pipe = multiprocessing.Pipe()
+num_steps = 3000
+
+
+class BokehControl:
+    """Class which holds control state for Bokeh."""
+    stop_button_pressed: bool = False
+
+
+# ==========================================================================
+# Set up network
+# ==========================================================================
+_, executable = load("mt_executable.pickle")
+network = MotionTracker(send_pipe,
+                        num_steps,
+                        executable=executable)
+
+
+# ==========================================================================
+# Bokeh Helpers
+# ==========================================================================
+def callback_run() -> None:
+    network.start()
+
+
+def callback_stop() -> None:
+    BokehControl.stop_button_pressed = True
+    network.stop()
+    sys.exit()
+
+
+def create_plot(plot_base_width,
+                data_shape,
+                title,
+                max_value=1) -> (figure, figure.image):
+    x_range = DataRange1d(start=0,
+                          end=data_shape[0],
+                          bounds=(0, data_shape[0]),
+                          range_padding=50,
+                          range_padding_units='percent')
+    y_range = DataRange1d(start=0,
+                          end=data_shape[1],
+                          bounds=(0, data_shape[1]),
+                          range_padding=50,
+                          range_padding_units='percent')
+
+    pw = plot_base_width
+    ph = int(pw * data_shape[1] / data_shape[0])
+    plot = figure(width=pw,
+                  height=ph,
+                  x_range=x_range,
+                  y_range=y_range,
+                  match_aspect=True,
+                  tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
+                  toolbar_location=None)
+
+    image = plot.image([], x=0, y=0, dw=data_shape[0], dh=data_shape[1],
+                       palette="Viridis256", level="image")
+
+    plot.add_layout(Title(text=title, align="center"), "above")
+
+    x_grid = list(range(data_shape[0]))
+    plot.xgrid[0].ticker = x_grid
+    y_grid = list(range(data_shape[1]))
+    plot.ygrid[0].ticker = y_grid
+    plot.xgrid.grid_line_color = None
+    plot.ygrid.grid_line_color = None
+
+    color = LinearColorMapper(palette="Viridis256", low=0, high=max_value)
+    image.glyph.color_mapper = color
+
+    cb = ColorBar(color_mapper=color)
+    plot.add_layout(cb, 'right')
+
+    return plot, image
+
+
+# ==========================================================================
+# Instantiating Bokeh document
+# ==========================================================================
+bokeh_document = curdoc()
+
+# Create plots.
+dvs_frame_p, dvs_frame_im = create_plot(
+    400, network.downsampled_shape, "DVS file input (max pooling)",
+    max_value=10)
+dnf_multipeak_rates_p, dnf_multipeak_rates_im = create_plot(
+    400, network.downsampled_shape, "DNF multi-peak (spike rates)")
+dnf_selective_rates_p, dnf_selective_rates_im = create_plot(
+    400, network.downsampled_shape, "DNF selective (spike rates)")
+
+# Add button widgets and configure their callbacks.
+button_run = Button(label="Run")
+button_run.on_click(callback_run)
+
+button_stop = Button(label="Close")
+button_stop.on_click(callback_stop)
+
+# Finalize the layout (with a spacer as placeholder).
+spacer = Spacer(height=40)
+bokeh_document.add_root(
+    gridplot([[button_run, None, button_stop],
+              [None, spacer, None],
+              [dvs_frame_p, dnf_multipeak_rates_p, dnf_selective_rates_p]],
+             toolbar_options=dict(logo=None)))
+
+
+# ==========================================================================
+# Bokeh Update
+# ==========================================================================
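+# The data_relayer (ProcessOut) sends a dict through the multiprocessing
+# pipe whose keys match update()'s parameter names, so main_loop() below can
+# apply it directly via functools.partial(update, **data_for_plot_dict).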
+def update(dvs_frame_ds_image,
+           dnf_multipeak_rates_ds_image,
+           dnf_selective_rates_ds_image) -> None:
+    dvs_frame_im.data_source.data["image"] = [dvs_frame_ds_image]
+    dnf_multipeak_rates_im.data_source.data["image"] = \
+        [dnf_multipeak_rates_ds_image]
+    dnf_selective_rates_im.data_source.data["image"] = \
+        [dnf_selective_rates_ds_image]
+
+
+# ==========================================================================
+# Bokeh Main loop
+# ==========================================================================
+def main_loop() -> None:
+    while not BokehControl.stop_button_pressed:
+        if recv_pipe.poll():
+            data_for_plot_dict = recv_pipe.recv()
+            bokeh_document.add_next_tick_callback(
+                functools.partial(update, **data_for_plot_dict))
+
+
+thread = threading.Thread(target=main_loop)
+thread.start()
diff --git a/demos/motion_tracking/motion_tracker.py b/demos/motion_tracking/motion_tracker.py
new file mode 100644
index 0000000..0705954
--- /dev/null
+++ b/demos/motion_tracking/motion_tracker.py
@@ -0,0 +1,233 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: BSD-3-Clause
+# See: https://spdx.org/licenses/
+
+import typing as ty
+from multiprocessing import Pipe
+
+from lava.lib.dnf.connect.connect import connect
+from lava.magma.core.run_conditions import RunSteps
+from demos.motion_tracking.dvs_file_input.process import DVSFileInput, \
+    PyDVSFileInputPM
+from lava.proc.lif.process import LIF
+from lava.proc.embedded_io.spike import NxToPyAdapter, PyToNxAdapter
+from lava.lib.dnf.kernels.kernels import MultiPeakKernel, SelectiveKernel
+from lava.lib.dnf.operations.operations import Convolution, Weights
+from demos.motion_tracking.process_out.process import ProcessOut, \
+    ProcessOutModel
+from lava.magma.core.run_configs import Loihi2HwCfg
+from demos.motion_tracking.rate_reader.process import RateReader
+from lava.magma.compiler.compiler import Compiler
+from lava.magma.core.process.message_interface_enum import ActorType
+from lava.magma.runtime.runtime import Runtime
+from lava.magma.compiler.executable import Executable
+from lava.magma.compiler.subcompilers.nc.ncproc_compiler import \
+    CompilerOptions

+CompilerOptions.verbose = True
+
+# Default configs for the motion tracking network.
+dvs_file_input_default_config = {
+    "true_height": 180,
+    "true_width": 240,
+    "file_path": "dvs_recording.aedat4",
+    "flatten": False,
+    "downsample_factor": 8,
+    "downsample_mode": "max_pooling"
+}
+
+multipeak_dnf_default_config = {
+    "in_conn": {
+        "weight": 8
+    },
+    "lif": {
+        "du": 2000,
+        "dv": 2000,
+        "vth": 30
+    },
+    "rec_conn": {
+        "amp_exc": 14,
+        "width_exc": [5, 5],
+        "amp_inh": -10,
+        "width_inh": [9, 9]
+    },
+    "out_conn": {
+        "weight": 20,
+    }
+}
+
+selective_dnf_default_config = {
+    "lif": {
+        "du": 809,
+        "dv": 2047,
+        "vth": 30
+    },
+    "rec_conn": {
+        "amp_exc": 7,
+        "width_exc": [7, 7],
+        "global_inh": -5
+    }
+}
+
+rate_reader_default_config = {
+    "buffer_size": 10,
+}
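+
+# Any of these configs can be overridden by passing a corresponding dict to
+# MotionTracker (e.g. multipeak_dnf_config={...}); omitted configs fall back
+# to the defaults above.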
+
+
+class MotionTracker:
+    """Class to set up the motion tracking network, compile it, and
+    initialize its runtime.
+
+    The network topology looks as follows:
+
+    DVS_Input -> PyToNxAdapter -> Multipeak DNF -> Selective DNF
+        |                               |                |
+        |                         NxToPyAdapter    NxToPyAdapter
+        |                               |                |
+        |                           RateReader      RateReader
+        |                               |                |
+        ---------------------------> DataRelay <---------
+    """
+
+    def __init__(self,
+                 send_pipe: type(Pipe),
+                 num_steps: int,
+                 blocking: ty.Optional[bool] = False,
+                 dvs_file_input_config: ty.Optional[dict] = None,
+                 multipeak_dnf_config: ty.Optional[dict] = None,
+                 selective_dnf_config: ty.Optional[dict] = None,
+                 rate_reader_config: ty.Optional[dict] = None,
+                 executable: ty.Optional[Executable] = None) -> None:
+
+        # Initialize configs (fall back to defaults where none are given).
+        dvs_file_input_config = \
+            dvs_file_input_config or dvs_file_input_default_config
+        multipeak_dnf_config = \
+            multipeak_dnf_config or multipeak_dnf_default_config
+        selective_dnf_config = \
+            selective_dnf_config or selective_dnf_default_config
+        rate_reader_config = \
+            rate_reader_config or rate_reader_default_config
+
+        # Initialize input params.
+        self.true_shape = (dvs_file_input_config["true_width"],
+                           dvs_file_input_config["true_height"])
+        self.file_path = dvs_file_input_config["file_path"]
+        self.downsample_factor = dvs_file_input_config["downsample_factor"]
+        self.downsample_mode = dvs_file_input_config["downsample_mode"]
+        self.flatten = dvs_file_input_config["flatten"]
+        self.downsampled_shape = \
+            (self.true_shape[0] // self.downsample_factor,
+             self.true_shape[1] // self.downsample_factor)
+
+        # Initialize multipeak DNF params.
+        self.multipeak_in_params = multipeak_dnf_config["in_conn"]
+        self.multipeak_lif_params = multipeak_dnf_config["lif"]
+        self.multipeak_rec_params = multipeak_dnf_config["rec_conn"]
+        self.multipeak_out_params = multipeak_dnf_config["out_conn"]
+
+        # Initialize selective DNF params.
+        self.selective_lif_params = selective_dnf_config["lif"]
+        self.selective_rec_params = selective_dnf_config["rec_conn"]
+
+        # Initialize rate reader params.
+        self.buffer_size_rate_reader = rate_reader_config["buffer_size"]
+
+        # Initialize send_pipe.
+        self.send_pipe = send_pipe
+
+        self._create_processes()
+        self._make_connections()
+
+        # Runtime creation and compilation.
+        exception_pm_map = {
+            DVSFileInput: PyDVSFileInputPM,
+            ProcessOut: ProcessOutModel
+        }
+        run_cfg = Loihi2HwCfg(exception_proc_model_map=exception_pm_map)
+        self.num_steps = num_steps
+        self.blocking = blocking
+
+        # Compilation.
+        compiler = Compiler()
+        if executable is None:
+            executable = compiler.compile(self.dvs_file_input,
+                                          run_cfg=run_cfg)
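+        # Passing a precompiled executable (as main_motion_tracking.py does
+        # by loading mt_executable.pickle) skips the compile step above, so
+        # the demo starts almost instantly.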
+
+        # Initialize the runtime.
+        mp = ActorType.MultiProcessing
+        self.runtime = Runtime(exe=executable,
+                               message_infrastructure_type=mp)
+        self.runtime.initialize()
+
+    def _create_processes(self) -> None:
+        # Instantiate processes running on CPU.
+        self.dvs_file_input = \
+            DVSFileInput(true_height=self.true_shape[1],
+                         true_width=self.true_shape[0],
+                         file_path=self.file_path,
+                         flatten=self.flatten,
+                         down_sample_factor=self.downsample_factor,
+                         down_sample_mode=self.downsample_mode)
+
+        self.rate_reader_multi_peak = \
+            RateReader(shape=self.downsampled_shape,
+                       buffer_size=self.buffer_size_rate_reader)
+
+        self.rate_reader_selective = \
+            RateReader(shape=self.downsampled_shape,
+                       buffer_size=self.buffer_size_rate_reader)
+
+        # Sends data to the pipe for plotting.
+        self.data_relayer = \
+            ProcessOut(shape_dvs_frame=self.downsampled_shape,
+                       shape_dnf=self.downsampled_shape,
+                       send_pipe=self.send_pipe)
+
+        # Instantiate C processes running on LMT.
+        self.c_injector = PyToNxAdapter(shape=self.downsampled_shape)
+        self.c_spike_reader_multi_peak = NxToPyAdapter(
+            shape=self.downsampled_shape)
+        self.c_spike_reader_selective = NxToPyAdapter(
+            shape=self.downsampled_shape)
+
+        # Instantiate processes running on Loihi 2.
+        self.dnf_multipeak = LIF(shape=self.downsampled_shape,
+                                 **self.multipeak_lif_params)
+        self.dnf_selective = LIF(shape=self.downsampled_shape,
+                                 **self.selective_lif_params)
+
+    def _make_connections(self) -> None:
+        # Connect input processes.
+        self.dvs_file_input.event_frame_out.connect(self.c_injector.inp)
+
+        # Connections around the multipeak DNF.
+        connect(self.c_injector.out, self.dnf_multipeak.a_in,
+                ops=[Weights(**self.multipeak_in_params)])
+        connect(self.dnf_multipeak.s_out, self.dnf_multipeak.a_in,
+                ops=[Convolution(MultiPeakKernel(
+                    **self.multipeak_rec_params))])
+        connect(self.dnf_multipeak.s_out, self.dnf_selective.a_in,
+                ops=[Weights(**self.multipeak_out_params)])
+
+        # Connections around the selective DNF.
+        connect(self.dnf_selective.s_out, self.dnf_selective.a_in,
+                ops=[Convolution(SelectiveKernel(
+                    **self.selective_rec_params))])
+
+        # Connect C reader processes.
+        self.dnf_multipeak.s_out.connect(
+            self.c_spike_reader_multi_peak.inp)
+        self.dnf_selective.s_out.connect(
+            self.c_spike_reader_selective.inp)
+
+        # Connect RateReaders.
+        self.c_spike_reader_multi_peak.out.connect(
+            self.rate_reader_multi_peak.in_port)
+        self.c_spike_reader_selective.out.connect(
+            self.rate_reader_selective.in_port)
+
+        # Connect ProcessOut (data relayer).
+        self.dvs_file_input.event_frame_out.reshape(
+            new_shape=self.downsampled_shape).connect(
+            self.data_relayer.dvs_frame_port)
+        self.rate_reader_multi_peak.out_port.connect(
+            self.data_relayer.dnf_multipeak_rates_port)
+        self.rate_reader_selective.out_port.connect(
+            self.data_relayer.dnf_selective_rates_port)
+
+    def start(self) -> None:
+        self.runtime.start(RunSteps(num_steps=self.num_steps,
+                                    blocking=self.blocking))
+
+    def stop(self) -> None:
+        self.runtime.wait()
+        self.runtime.stop()
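+
+
+# Typical usage (a sketch mirroring main_motion_tracking.py):
+#
+#     _, executable = load("mt_executable.pickle")
+#     network = MotionTracker(send_pipe, num_steps, executable=executable)
+#     network.start()  # triggered by the "Run" button
+#     network.stop()   # triggered by the "Close" button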
+ """ + def __init__(self, + shape_dvs_frame, + shape_dnf, + send_pipe) -> None: + super().__init__(shape_dvs_frame=shape_dvs_frame, + shape_dnf=shape_dnf, + send_pipe=send_pipe) + self.dvs_frame_port = InPort(shape=shape_dvs_frame) + self.dnf_multipeak_rates_port = InPort(shape=shape_dnf) + self.dnf_selective_rates_port = InPort(shape=shape_dnf) + + +@implements(proc=ProcessOut, protocol=LoihiProtocol) +@requires(CPU) +class ProcessOutModel(PyLoihiProcessModel): + dvs_frame_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float) + dnf_multipeak_rates_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float) + dnf_selective_rates_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float) + + def __init__(self, proc_params) -> None: + super().__init__(proc_params) + self._send_pipe = proc_params["send_pipe"] + + def run_spk(self) -> None: + dvs_frame = self.dvs_frame_port.recv() + dnf_multipeak_rates = self.dnf_multipeak_rates_port.recv() + dnf_selective_rates = self.dnf_selective_rates_port.recv() + + dvs_frame_ds_image = np.rot90(dvs_frame) + dnf_multipeak_rates_ds_image = np.rot90(dnf_multipeak_rates) + dnf_selective_rates_ds_image = np.rot90(dnf_selective_rates) + + data_dict = { + "dvs_frame_ds_image": dvs_frame_ds_image, + "dnf_multipeak_rates_ds_image": dnf_multipeak_rates_ds_image, + "dnf_selective_rates_ds_image": dnf_selective_rates_ds_image, + } + + self._send_pipe.send(data_dict) diff --git a/demos/motion_tracking/rate_reader/process.py b/demos/motion_tracking/rate_reader/process.py new file mode 100644 index 0000000..1a403e0 --- /dev/null +++ b/demos/motion_tracking/rate_reader/process.py @@ -0,0 +1,57 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +from lava.magma.core.process.process import AbstractProcess +from lava.magma.core.process.ports.ports import InPort, OutPort +from lava.magma.core.process.variable import Var +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.magma.core.model.py.type import LavaPyType +from lava.magma.core.model.py.ports import PyInPort, PyOutPort +from lava.magma.core.resources import CPU +from lava.magma.core.model.py.model import PyLoihiProcessModel +from lava.magma.core.decorator import implements, requires + + +class RateReader(AbstractProcess): + """ + Process that stores recent spikes in a buffer and computes the spike rate + at each timestep. + """ + def __init__(self, shape, buffer_size, num_steps=1): + super().__init__(shape=shape, + buffer_size=buffer_size, + num_steps=num_steps) + self.in_port = InPort(shape) + self.buffer = Var(shape=shape + (buffer_size,)) + self.rate = Var(shape=shape, init=0) + self.out_port = OutPort(shape) + + +@implements(proc=RateReader, protocol=LoihiProtocol) +@requires(CPU) +class PyRateReaderPMDense(PyLoihiProcessModel): + in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int, 8) + buffer: np.ndarray = LavaPyType(np.ndarray, np.int32) + rate: np.ndarray = LavaPyType(np.ndarray, float) + out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float) + + def __init__(self, proc_params): + super().__init__(proc_params) + self._buffer_size = proc_params["buffer_size"] + + def post_guard(self): + # Ensures that run_post_mgmt runs after run_spk at every + # time step. + return True + + def run_post_mgmt(self): + # Runs after run_spk in every time step and computes the + # spike rate from the buffer. 
+        dvs_frame_ds_image = np.rot90(dvs_frame)
+        dnf_multipeak_rates_ds_image = np.rot90(dnf_multipeak_rates)
+        dnf_selective_rates_ds_image = np.rot90(dnf_selective_rates)
+
+        data_dict = {
+            "dvs_frame_ds_image": dvs_frame_ds_image,
+            "dnf_multipeak_rates_ds_image": dnf_multipeak_rates_ds_image,
+            "dnf_selective_rates_ds_image": dnf_selective_rates_ds_image,
+        }
+
+        self._send_pipe.send(data_dict)
diff --git a/demos/motion_tracking/rate_reader/process.py b/demos/motion_tracking/rate_reader/process.py
new file mode 100644
index 0000000..1a403e0
--- /dev/null
+++ b/demos/motion_tracking/rate_reader/process.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: BSD-3-Clause
+# See: https://spdx.org/licenses/
+
+import numpy as np
+from lava.magma.core.process.process import AbstractProcess
+from lava.magma.core.process.ports.ports import InPort, OutPort
+from lava.magma.core.process.variable import Var
+from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
+from lava.magma.core.model.py.type import LavaPyType
+from lava.magma.core.model.py.ports import PyInPort, PyOutPort
+from lava.magma.core.resources import CPU
+from lava.magma.core.model.py.model import PyLoihiProcessModel
+from lava.magma.core.decorator import implements, requires
+
+
+class RateReader(AbstractProcess):
+    """
+    Process that stores recent spikes in a buffer and computes the spike
+    rate at each time step.
+    """
+    def __init__(self, shape, buffer_size, num_steps=1):
+        super().__init__(shape=shape,
+                         buffer_size=buffer_size,
+                         num_steps=num_steps)
+        self.in_port = InPort(shape)
+        self.buffer = Var(shape=shape + (buffer_size,))
+        self.rate = Var(shape=shape, init=0)
+        self.out_port = OutPort(shape)
+
+
+@implements(proc=RateReader, protocol=LoihiProtocol)
+@requires(CPU)
+class PyRateReaderPMDense(PyLoihiProcessModel):
+    in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int, 8)
+    buffer: np.ndarray = LavaPyType(np.ndarray, np.int32)
+    rate: np.ndarray = LavaPyType(np.ndarray, float)
+    out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)
+
+    def __init__(self, proc_params):
+        super().__init__(proc_params)
+        self._buffer_size = proc_params["buffer_size"]
+
+    def post_guard(self):
+        # Ensures that run_post_mgmt runs after run_spk at every time step.
+        return True
+
+    def run_post_mgmt(self):
+        # Runs after run_spk at every time step; stores the newest spikes in
+        # the ring buffer and recomputes the spike rate.
+        spikes = self.in_port.recv()
+        self.buffer[..., (self.time_step - 1) % self._buffer_size] = spikes
+        self.rate = np.mean(self.buffer, axis=-1)
+
+    def run_spk(self):
+        self.out_port.send(self.rate)
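+
+
+# Rate computation sketch: with the demo's buffer_size=10, `buffer` holds
+# the last ten spike frames, and the per-pixel mean over that window (a
+# value in [0, 1] for binary spikes) is sent out by run_spk at every time
+# step.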