From 881d87e311f1485b51521111e1e617304d1859db Mon Sep 17 00:00:00 2001 From: smm-ncl <145007783+smm-ncl@users.noreply.github.com> Date: Tue, 23 Jan 2024 18:22:46 +0100 Subject: [PATCH] SigmaS4Delta Neuronmodel and Layer with Unittests (#830) * first wokring version * S4D model cleaned * update license * fix imports * linting * incorporate reviews * update docstring --- src/lava/proc/s4d/models.py | 169 ++++++++++++++++++++++++ src/lava/proc/s4d/process.py | 167 ++++++++++++++++++++++++ src/lava/proc/sdn/process.py | 5 +- tests/lava/proc/s4d/s4d_A.dat.npy | 3 + tests/lava/proc/s4d/s4d_B.dat.npy | 3 + tests/lava/proc/s4d/s4d_C.dat.npy | 3 + tests/lava/proc/s4d/test_models.py | 195 ++++++++++++++++++++++++++++ tests/lava/proc/s4d/test_process.py | 83 ++++++++++++ tests/lava/proc/s4d/utils.py | 87 +++++++++++++ 9 files changed, 713 insertions(+), 2 deletions(-) create mode 100644 src/lava/proc/s4d/models.py create mode 100644 src/lava/proc/s4d/process.py create mode 100644 tests/lava/proc/s4d/s4d_A.dat.npy create mode 100644 tests/lava/proc/s4d/s4d_B.dat.npy create mode 100644 tests/lava/proc/s4d/s4d_C.dat.npy create mode 100644 tests/lava/proc/s4d/test_models.py create mode 100644 tests/lava/proc/s4d/test_process.py create mode 100644 tests/lava/proc/s4d/utils.py diff --git a/src/lava/proc/s4d/models.py b/src/lava/proc/s4d/models.py new file mode 100644 index 000000000..c409ba3af --- /dev/null +++ b/src/lava/proc/s4d/models.py @@ -0,0 +1,169 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import numpy as np +from typing import Any, Dict +from lava.proc.sdn.models import AbstractSigmaDeltaModel +from lava.magma.core.decorator import implements, requires, tag +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer +from lava.magma.core.resources import CPU +from lava.magma.core.model.py.ports import PyInPort, PyOutPort +from lava.magma.core.model.py.type import LavaPyType +from lava.magma.core.model.sub.model import AbstractSubProcessModel +from lava.proc.sparse.process import Sparse + + +class AbstractSigmaS4dDeltaModel(AbstractSigmaDeltaModel): + a_in = None + s_out = None + + # SigmaDelta Variables + vth = None + sigma = None + act = None + residue = None + error = None + state_exp = None + bias = None + + # S4 Variables + a = None + b = None + c = None + s4_state = None + s4_exp = None + + def __init__(self, proc_params: Dict[str, Any]) -> None: + """ + Sigma delta neuron model that implements S4D + (as described by Gu et al., 2022) dynamics as its activation function. + + Relevant parameters in proc_params + -------------------------- + a: np.ndarray + Diagonal elements of the state matrix of the S4D model. + b: np.ndarray + Diagonal elements of the input matrix of the S4D model. + c: np.ndarray + Diagonal elements of the output matrix of the S4D model. + s4_state: np.ndarray + State vector of the S4D model. + """ + super().__init__(proc_params) + self.a = self.proc_params['a'] + self.b = self.proc_params['b'] + self.c = self.proc_params['c'] + self.s4_state = self.proc_params['s4_state'] + + def activation_dynamics(self, sigma_data: np.ndarray) -> np.ndarray: + """Sigma Delta activation dynamics. Performs S4D dynamics. + + This function simulates the behavior of a linear time-invariant system + with diagonalized state-space representation. + (For reference see Gu et al., 2022) + + The state-space equations are given by: + s4_state_{k+1} = A * s4_state_k + B * input_k + act_k = C * s4_state_k + + where: + - s4_state_k is the state vector at time step k, + - input_k is the input vector at time step k, + - act_k is the output vector at time step k, + - A is the diagonal state matrix, + - B is the diagonal input matrix, + - C is the diagonal output matrix. + + The function computes the next output step of the + system for the given input signal. + + Parameters + ---------- + sigma_data: np.ndarray + sigma decoded data + + Returns + ------- + np.ndarray + activation output + """ + + self.s4_state = self.s4_state * self.a + sigma_data * self.b + act = self.c * self.s4_state * 2 + return act + + +@implements(proc=SigmaS4dDelta, protocol=LoihiProtocol) +@requires(CPU) +@tag('floating_pt') +class PySigmaS4dDeltaModelFloat(AbstractSigmaS4dDeltaModel): + """Floating point implementation of SigmaS4dDelta neuron.""" + a_in = LavaPyType(PyInPort.VEC_DENSE, float) + s_out = LavaPyType(PyOutPort.VEC_DENSE, float) + + vth: np.ndarray = LavaPyType(np.ndarray, float) + sigma: np.ndarray = LavaPyType(np.ndarray, float) + act: np.ndarray = LavaPyType(np.ndarray, float) + residue: np.ndarray = LavaPyType(np.ndarray, float) + error: np.ndarray = LavaPyType(np.ndarray, float) + bias: np.ndarray = LavaPyType(np.ndarray, float) + + state_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) + cum_error: np.ndarray = LavaPyType(np.ndarray, bool, precision=1) + spike_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) + s4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) + + # S4 vaiables + s4_state: np.ndarray = LavaPyType(np.ndarray, float) + a: np.ndarray = LavaPyType(np.ndarray, float) + b: np.ndarray = LavaPyType(np.ndarray, float) + c: np.ndarray = LavaPyType(np.ndarray, float) + + def run_spk(self) -> None: + # Receive synaptic input + a_in_data = self.a_in.recv() + s_out = self.dynamics(a_in_data) + self.s_out.send(s_out) + + +@implements(proc=SigmaS4dDeltaLayer, protocol=LoihiProtocol) +class SubDenseLayerModel(AbstractSubProcessModel): + def __init__(self, proc): + """Builds (Sparse -> S4D -> Sparse) connection of the process.""" + conn_weights = proc.proc_params.get("conn_weights") + shape = proc.proc_params.get("shape") + state_exp = proc.proc_params.get("state_exp") + num_message_bits = proc.proc_params.get("num_message_bits") + s4_exp = proc.proc_params.get("s4_exp") + d_states = proc.proc_params.get("d_states") + a = proc.proc_params.get("a") + b = proc.proc_params.get("b") + c = proc.proc_params.get("c") + vth = proc.proc_params.get("vth") + + # Instantiate processes + self.sparse1 = Sparse(weights=conn_weights.T, weight_exp=state_exp, + num_message_bits=num_message_bits) + self.sigma_S4d_delta = SigmaS4dDelta(shape=(shape[0] * d_states,), + vth=vth, + state_exp=state_exp, + s4_exp=s4_exp, + a=a, + b=b, + c=c) + self.sparse2 = Sparse(weights=conn_weights, weight_exp=state_exp, + num_message_bits=num_message_bits) + + # Make connections Sparse -> SigmaS4Delta -> Sparse + proc.in_ports.s_in.connect(self.sparse1.in_ports.s_in) + self.sparse1.out_ports.a_out.connect(self.sigma_S4d_delta.in_ports.a_in) + self.sigma_S4d_delta.out_ports.s_out.connect(self.sparse2.s_in) + self.sparse2.out_ports.a_out.connect(proc.out_ports.a_out) + + # Set aliases + proc.vars.a.alias(self.sigma_S4d_delta.vars.a) + proc.vars.b.alias(self.sigma_S4d_delta.vars.b) + proc.vars.c.alias(self.sigma_S4d_delta.vars.c) + proc.vars.s4_state.alias(self.sigma_S4d_delta.vars.s4_state) diff --git a/src/lava/proc/s4d/process.py b/src/lava/proc/s4d/process.py new file mode 100644 index 000000000..218e7292c --- /dev/null +++ b/src/lava/proc/s4d/process.py @@ -0,0 +1,167 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import typing as ty +import numpy as np +from lava.magma.core.process.process import AbstractProcess +from lava.magma.core.process.variable import Var +from lava.magma.core.process.ports.ports import InPort, OutPort +from lava.proc.sdn.process import ActivationMode, SigmaDelta + + +class SigmaS4dDelta(SigmaDelta, AbstractProcess): + def __init__( + self, + shape: ty.Tuple[int, ...], + vth: ty.Union[int, float], + a: float, + b: float, + c: float, + state_exp: ty.Optional[int] = 0, + s4_exp: ty.Optional[int] = 0) -> None: + """ + Sigma delta neuron process that implements S4D (described by + Gu et al., 2022) dynamics as its activation function. + + This process simulates the behavior of a linear time-invariant system + with diagonal state-space representation. + The state-space equations are given by: + s4_state_{k+1} = A * s4_state_k + B * inp_k + act_k = C * s4_state_k + + where: + - s4_state_k is the state vector at time step k, + - inp_k is the input vector at time step k, + - act_k is the output vector at time step k, + - A is the diagonal state matrix, + - B is the diagonal input matrix, + - C is the diagonal output matrix. + + Parameters + ---------- + shape: Tuple + Shape of the sigma process. + vth: int or float + Threshold of the delta encoder. + a: np.ndarray + Diagonal elements of the state matrix of the S4D model. + b: np.ndarray + Diagonal elements of the input matrix of the S4D model. + c: np.ndarray + Diagonal elements of the output matrix of the S4D model. + state_exp: int + Scaling exponent with base 2 for the reconstructed sigma variables. + Note: This should only be used for nc models. + Default is 0. + s4_exp: int + Scaling exponent with base 2 for the S4 state variables. + Note: This should only be used for nc models. + Default is 0. + """ + + super().__init__(shape=shape, + vth=vth, + a=a, + b=b, + c=c, + s4_state=0, + state_exp=state_exp, + s4_exp=s4_exp) + + # Variables for S4 + self.a = Var(shape=shape, init=a) + self.b = Var(shape=shape, init=b) + self.c = Var(shape=shape, init=c) + self.s4_state = Var(shape=shape, init=0) + self.s4_exp = Var(shape=(1,), init=s4_exp) + + +class SigmaS4dDeltaLayer(AbstractProcess): + def __init__( + self, + shape: ty.Tuple[int, ...], + vth: ty.Union[int, float], + a: float, + b: float, + c: float, + d_states: ty.Optional[int] = 1, + s4_exp: ty.Optional[int] = 0, + num_message_bits: ty.Optional[int] = 24, + state_exp: ty.Optional[int] = 0) -> None: + """ + Combines S4D neuron with Sparse Processes that allow for multiple + d_states. + + Connectivity: Sparse -> SigmaS4dDelta -> Sparse. + Relieves user from computing required connection weights for multiple + d_states. + + Parameters + ---------- + shape: Tuple + Shape of the sigma process. + vth: int or float + Threshold of the delta encoder. + a: np.ndarray + Diagonal elements of the state matrix of the S4D model. + b: np.ndarray + Diagonal elements of the input matrix of the S4D model. + c: np.ndarray + Diagonal elements of the output matrix of the S4D model. + d_states: int + Number of hidden states of the S4D model. + Default is 1. + state_exp: int + Scaling exponent with base 2 for the reconstructed sigma variables. + Note: Only relevant for nc model. + Default is 0. + num_message_bits: int + Number of message bits to be used in Sparse connection processes. + Note: Only relevant for nc model. + s4_exp: int + Scaling exponent with base 2 for the S4 state variables. + Note: Only relevant for nc model. + Default is 0. + """ + + # Automatically takes care of expansion and reduction of dimensionality + # for multiple hidden states (d_states) + conn_weights = np.kron(np.eye(shape[0]), np.ones(d_states)) + s4_state = 0 + super().__init__(shape=shape, + vth=vth, + a=a, + b=b, + c=c, + s4_exp=s4_exp, + s4_state=s4_state, + conn_weights=conn_weights, + num_message_bits=num_message_bits, + d_states=d_states, + state_exp=state_exp, + act_mode=ActivationMode.UNIT) + + # Ports + self.s_in = InPort(shape=shape) + self.a_out = OutPort(shape=shape) + + # General variables + self.state_exp = Var(shape=(1,), init=state_exp) + + # Variables for S4 + self.a = Var(shape=(shape[0] * d_states,), init=a) + self.b = Var(shape=(shape[0] * d_states,), init=b) + self.c = Var(shape=(shape[0] * d_states,), init=c) + self.s4_state = Var(shape=(shape[0] * d_states,), init=0) + self.S4_exp = Var(shape=(1,), init=s4_exp) + + # Variables for connecting Dense processes + # Project input_dim to input_dim * d_states + self.conn_weights = Var(shape=shape, init=conn_weights) + self.num_message_bits = Var(shape=(1,), init=num_message_bits) + + @property + def shape(self) -> ty.Tuple[int, ...]: + """Return shape of the Process.""" + return self.proc_params['shape'] diff --git a/src/lava/proc/sdn/process.py b/src/lava/proc/sdn/process.py index 32f083247..6de494706 100644 --- a/src/lava/proc/sdn/process.py +++ b/src/lava/proc/sdn/process.py @@ -126,7 +126,8 @@ def __init__( act_mode: ty.Optional[ActivationMode] = ActivationMode.RELU, cum_error: ty.Optional[bool] = False, spike_exp: ty.Optional[int] = 0, - state_exp: ty.Optional[int] = 0) -> None: + state_exp: ty.Optional[int] = 0, + **kwargs) -> None: """Sigma delta neuron process. At the moment only ReLu activation is supported. Spike mechanism based on accumulated error is also supported. @@ -173,7 +174,7 @@ def __init__( """ super().__init__(shape=shape, vth=vth, bias=bias, act_mode=act_mode, cum_error=cum_error, - spike_exp=spike_exp, state_exp=state_exp) + spike_exp=spike_exp, state_exp=state_exp, **kwargs) # scaling factor for fixed precision scaling vth = vth * (1 << (spike_exp + state_exp)) bias = bias * (1 << (spike_exp + state_exp)) diff --git a/tests/lava/proc/s4d/s4d_A.dat.npy b/tests/lava/proc/s4d/s4d_A.dat.npy new file mode 100644 index 000000000..081fdd6ec --- /dev/null +++ b/tests/lava/proc/s4d/s4d_A.dat.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e4c9f5f11d3139b86ccdfe3d8e2179566dc8ac8da31a4a23951dba174425663 +size 5248 diff --git a/tests/lava/proc/s4d/s4d_B.dat.npy b/tests/lava/proc/s4d/s4d_B.dat.npy new file mode 100644 index 000000000..916b75ffb --- /dev/null +++ b/tests/lava/proc/s4d/s4d_B.dat.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b67f99c0a172c862abfc65b3aabeba9bc91fe1f2254d0df066d19c9b3e3b8fe +size 5248 diff --git a/tests/lava/proc/s4d/s4d_C.dat.npy b/tests/lava/proc/s4d/s4d_C.dat.npy new file mode 100644 index 000000000..910f44a96 --- /dev/null +++ b/tests/lava/proc/s4d/s4d_C.dat.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:055720b65a2eb0bf0043989b1a078cc028f7a105a0cb394ba03cdbf3adac8ac1 +size 5248 diff --git a/tests/lava/proc/s4d/test_models.py b/tests/lava/proc/s4d/test_models.py new file mode 100644 index 000000000..c08e19cf8 --- /dev/null +++ b/tests/lava/proc/s4d/test_models.py @@ -0,0 +1,195 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import unittest +import numpy as np +from typing import Tuple +import lava.proc.io as io +from lava.magma.core.run_conditions import RunSteps +from lava.proc.sdn.process import ActivationMode, SigmaDelta +from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer +from lava.proc.sparse.process import Sparse +from lava.magma.core.run_configs import Loihi2SimCfg +from tests.lava.proc.s4d.utils import get_coefficients, run_original_model + + +class TestSigmaS4DDeltaModels(unittest.TestCase): + """Tests for SigmaS4Delta neuron""" + def run_in_lava( + self, + inp, + a: np.ndarray, + b: np.ndarray, + c: np.ndarray, + num_steps: int, + model_dim: int, + d_states: int, + use_layer: bool) -> Tuple[np.ndarray]: + + """ Run S4d model in lava. + + Parameters + ---------- + inp : np.ndarray + Input signal to the model. + num_steps : int + Number of time steps to simulate the model. + model_dim : int + Dimensionality of the model. + d_states : int + Number of model states. + use_layer : bool + Whether to use the layer implementation of the model + (SigmaS4DeltaLayer, helpful for multiple d_states) or just + the neuron model (SigmaS4Delta). + + Returns + ------- + Tuple[np.ndarray] + Tuple containing the output of the model simulation. + """ + + a = a[:model_dim * d_states] + b = b[:model_dim * d_states] + c = c[:model_dim * d_states] + + diff = inp[:, 1:] - inp[:, :-1] + diff = np.concatenate((inp[:, :1], diff), axis=1) + + spiker = io.source.RingBuffer(data=diff) + receiver = io.sink.RingBuffer(shape=(model_dim,), buffer=num_steps) + + if use_layer: + s4d_layer = SigmaS4dDeltaLayer(shape=(model_dim,), + d_states=d_states, + num_message_bits=24, + vth=0, + a=a, + b=b, + c=c) + buffer_neuron = SigmaDelta(shape=(model_dim,), + vth=0, + cum_error=True, + act_mode=ActivationMode.UNIT) + spiker.s_out.connect(s4d_layer.s_in) + s4d_layer.a_out.connect(buffer_neuron.a_in) + buffer_neuron.s_out.connect(receiver.a_in) + + else: + sparse = Sparse(weights=np.eye(model_dim), num_message_bits=24) + s4d_neuron = SigmaS4dDelta(shape=((model_dim,)), + vth=0, + a=a, + b=b, + c=c) + spiker.s_out.connect(sparse.s_in) + sparse.a_out.connect(s4d_neuron.a_in) + s4d_neuron.s_out.connect(receiver.a_in) + + run_condition = RunSteps(num_steps=num_steps) + run_config = Loihi2SimCfg() + + spiker.run(condition=run_condition, run_cfg=run_config) + output = receiver.data.get() + spiker.stop() + + output = np.cumsum(output, axis=1) + + return output + + def test_py_model_vs_original_equations(self) -> None: + """Tests that the pymodel for SigmaS4dDelta outputs approximately + the same values as the original S4D equations. + """ + a, b, c = get_coefficients() + model_dim = 3 + d_states = 1 + n_steps = 5 + np.random.seed(0) + inp = np.random.random((model_dim, n_steps)) * 2**6 + + out_chip = self.run_in_lava(inp=inp, + a=a, + b=b, + c=c, + num_steps=n_steps, + model_dim=model_dim, + d_states=d_states, + use_layer=False + ) + out_original_model = run_original_model(inp=inp, + model_dim=model_dim, + d_states=d_states, + num_steps=n_steps, + a=a, + b=b, + c=c) + + np.testing.assert_array_equal(out_original_model[:, :-1], + out_chip[:, 1:]) + + def test_py_model_layer_vs_original_equations(self) -> None: + """ Tests that the pymodel for SigmaS4DeltaLayer outputs approximately + the same values as the original S4D equations for multiple d_states. + """ + a, b, c = get_coefficients() + model_dim = 3 + d_states = 3 + n_steps = 5 + np.random.seed(1) + inp = np.random.random((model_dim, n_steps)) * 2**6 + + out_chip = self.run_in_lava(inp=inp, + a=a, + b=b, + c=c, + num_steps=n_steps, + model_dim=model_dim, + d_states=d_states, + use_layer=True, + ) + out_original_model = run_original_model(inp=inp, + model_dim=model_dim, + d_states=d_states, + num_steps=n_steps, + a=a, + b=b, + c=c) + + np.testing.assert_allclose(out_original_model[:, :-2], out_chip[:, 2:]) + + def test_py_model_vs_py_model_layer(self) -> None: + """Tests that the pymodel for SigmaS4DeltaLayer outputs approximately + the same values as just the SigmaS4DDelta Model with one hidden dim. + """ + a, b, c = get_coefficients() + model_dim = 3 + d_states = 1 + n_steps = 5 + np.random.seed(2) + inp = np.random.random((model_dim, n_steps)) * 2**6 + + out_just_model = self.run_in_lava(inp=inp, + a=a, + b=b, + c=c, + num_steps=n_steps, + model_dim=model_dim, + d_states=d_states, + use_layer=False) + + out_layer = self.run_in_lava(inp=inp, + a=a, + b=b, + c=c, + num_steps=n_steps, + model_dim=model_dim, + d_states=d_states, + use_layer=True) + + np.testing.assert_allclose(out_layer[:, 1:], out_just_model[:, :-1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/proc/s4d/test_process.py b/tests/lava/proc/s4d/test_process.py new file mode 100644 index 000000000..5f6cbcc9c --- /dev/null +++ b/tests/lava/proc/s4d/test_process.py @@ -0,0 +1,83 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import unittest +import numpy as np +from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer + + +class TestSigmaS4dDeltaProcess(unittest.TestCase): + """Tests for SigmaS4dDelta Class""" + + def test_init(self) -> None: + """Tests instantiation of SigmaS4dDelta""" + shape = 10 + vth = 10 + state_exp = 6 + s4_exp = 12 + a = np.ones(shape) * 0.5 + b = np.ones(shape) * 0.8 + c = np.ones(shape) * 0.9 + sigma_s4_delta = SigmaS4dDelta(shape=(shape,), + vth=vth, + state_exp=state_exp, + s4_exp=s4_exp, + a=a, + b=b, + c=c) + + # determined by user - S4 part + self.assertEqual(sigma_s4_delta.shape, (shape,)) + self.assertEqual(sigma_s4_delta .vth.init, vth * 2 ** state_exp) + self.assertEqual(sigma_s4_delta.s4_exp.init, s4_exp) + np.testing.assert_array_equal(sigma_s4_delta.a.init, a) + np.testing.assert_array_equal(sigma_s4_delta.b.init, b) + np.testing.assert_array_equal(sigma_s4_delta.c.init, c) + self.assertEqual(sigma_s4_delta.state_exp.init, state_exp) + self.assertEqual(sigma_s4_delta.s4_state.init, 0) + + # default sigmadelta params - inherited from SigmaDelta class + self.assertEqual(sigma_s4_delta.cum_error.init, False) + self.assertEqual(sigma_s4_delta.spike_exp.init, 0) + self.assertEqual(sigma_s4_delta.bias.init, 0) + + +class TestSigmaS4DeltaLayer(unittest.TestCase): + """Tests for SigmaS4dDeltaLayer Class""" + + def test_init(self) -> None: + """Tests instantiation of SigmaS4dDeltaLayer """ + shape = 10 + vth = 10 + state_exp = 6 + s4_exp = 12 + d_states = 5 + a = np.ones(shape) * 0.5 + b = np.ones(shape) * 0.8 + c = np.ones(shape) * 0.9 + + sigma_s4d_delta_layer = SigmaS4dDeltaLayer(shape=(shape,), + d_states=d_states, + vth=vth, + state_exp=state_exp, + s4_exp=s4_exp, + a=a, + b=b, + c=c) + # determined by user - S4 part + self.assertEqual(sigma_s4d_delta_layer.shape, (shape,)) + self.assertEqual(sigma_s4d_delta_layer.S4_exp.init, s4_exp) + np.testing.assert_array_equal(sigma_s4d_delta_layer.a.init, a) + np.testing.assert_array_equal(sigma_s4d_delta_layer.b.init, b) + np.testing.assert_array_equal(sigma_s4d_delta_layer.c.init, c) + self.assertEqual(sigma_s4d_delta_layer.state_exp.init, state_exp) + self.assertEqual(sigma_s4d_delta_layer.s4_state.init, 0) + + # determined by user/via number of states and shape + np.testing.assert_array_equal(sigma_s4d_delta_layer.conn_weights.init, + np.kron(np.eye(shape), np.ones(d_states))) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/lava/proc/s4d/utils.py b/tests/lava/proc/s4d/utils.py new file mode 100644 index 000000000..3afd48b2e --- /dev/null +++ b/tests/lava/proc/s4d/utils.py @@ -0,0 +1,87 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause +# See: https://spdx.org/licenses/ + +import os +import numpy as np +from typing import Tuple + + +def get_coefficients() -> [np.ndarray, np.ndarray, np.ndarray]: + curr_dir = os.path.dirname(os.path.realpath(__file__)) + + # Initialize A, B and C with values trained on efficientnet features. + s4d_A = np.load(curr_dir + "/s4d_A.dat.npy").flatten() + s4d_B = np.load(curr_dir + "/s4d_B.dat.npy").flatten() + s4d_C = np.load(curr_dir + "/s4d_C.dat.npy").flatten().flatten() + return s4d_A, s4d_B, s4d_C + + +def run_original_model( + inp: np.ndarray, + num_steps: int, + model_dim: int, + d_states: int, + a: np.ndarray, + b: np.ndarray, + c: np.ndarray) -> Tuple[np.ndarray]: + """ + Run original S4d model. + + This function simulates the behavior of a linear time-invariant system + with diagonalized state-space representation. (S4D) + The state-space equations are given by: + s4_state_{k+1} = A * s4_state_k + B * input_k + out_k = C * s4_state_k + + where: + - s4_state_k is the state vector at time step k, + - input_k is the input vector at time step k, + - out_k is the output vector at time step k, + - A is the diagonal state matrix, + - B is the diagonal input matrix, + - C is the diagonal output matrix. + + The function computes the next output step of the + system for the given input signal. + + The function computes the output of the system for the given input signal + over num_steps time steps. + + Parameters + ---------- + input: np.ndarray + Input signal to the model. + num_steps: int + Number of time steps to simulate the model. + model_dim: int + Dimensionality of the model. + d_states: int + Number of model states. + a: np.ndarray + Diagonal elements of the state matrix of the S4D model. + b: np.ndarray + Diagonal elements of the input matrix of the S4D model. + c: np.ndarray + Diagonal elements of the output matrix of the S4D model. + + Returns + ------- + Tuple[np.ndarray] + Tuple containing the output of the model simulation. + """ + + a = a[:model_dim * d_states] + b = b[:model_dim * d_states] + c = c[:model_dim * d_states] + expansion_weights = np.kron(np.eye(model_dim), np.ones(d_states)) + expanded_inp = np.matmul(expansion_weights.T, inp) + out = np.zeros((model_dim * d_states, num_steps)) + s4_state = np.zeros((model_dim * d_states,)).flatten() + + for idx, inp in enumerate(expanded_inp.T): + s4_state = np.multiply(s4_state, a) + np.multiply(inp, b) + out[:, idx] = np.multiply(c, s4_state) * 2 + + out = np.matmul(expansion_weights, out) + return out