Skip to content

Commit

Permalink
Adding current-based LIF neurons and example Jupyter notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Saunders committed Jun 11, 2018
1 parent 48f3c2b commit 7895b44
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 31 deletions.
134 changes: 104 additions & 30 deletions bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def step(self, inpts, dt):
if self.traces:
# Decay and set spike traces.
self.x -= dt * self.trace_tc * self.x
self.x = self.x.masked_fill(self.s, 1)
self.x.masked_fill_(self.s, 1)

@abstractmethod
def _reset(self):
Expand Down Expand Up @@ -179,11 +179,11 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)

# Refractoriness and voltage reset.
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
self.v = self.v.masked_fill(self.s, self.reset)
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)

# Integrate input and decay voltages.
self.v += inpts
Expand Down Expand Up @@ -222,14 +222,14 @@ def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
'''
super().__init__(n, shape, traces, trace_tc)

self.rest = rest # Rest voltage.
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.rest = rest # Rest voltage.
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.

self.v = torch.zeros(self.shape) + self.rest # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def step(self, inpts, dt):
'''
Expand All @@ -247,11 +247,11 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)

# Refractoriness and voltage reset.
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
self.v = self.v.masked_fill(self.s, self.reset)
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)

# Integrate inputs.
self.v += inpts
Expand All @@ -267,6 +267,80 @@ def _reset(self):
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.


class CurrentLIFNodes(Nodes):
'''
Layer of current-based leaky integrate-and-fire (LIF) neurons.
'''
def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
reset=-65.0, refrac=5, decay=1e-2, i_decay=2e-2, trace_tc=5e-2):
'''
Instantiates a layer of synaptic input current-based LIF neurons.
Inputs:
| :code:`n` (:code:`int`): The number of neurons in the layer.
| :code:`shape` (:code:`iterable[int]`): The dimensionality of the layer.
| :code:`traces` (:code:`bool`): Whether to record spike traces.
| :code:`thresh` (:code:`float`): Spike threshold voltage.
| :code:`rest` (:code:`float`): Resting membrane voltage.
| :code:`reset` (:code:`float`): Post-spike reset voltage.
| :code:`refrac` (:code:`int`): Refractory (non-firing) period of the neuron.
| :code:`decay` (:code:`float`): Time constant of neuron voltage decay.
| :code:`i_decay` (:code:`float`): Time constant of synaptic input current decay.
| :code:`trace_tc` (:code:`float`): Time constant of spike trace decay.
'''
super().__init__(n, shape, traces, trace_tc)

self.rest = rest # Rest voltage.
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.i_decay = i_decay # Rate of decay of synaptic input current.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.i = torch.zeros(self.shape) # Synaptic input currents.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def step(self, inpts, dt):
'''
Runs a single simulation step.
Inputs:
| :code:`inpts` (:code:`torch.Tensor`): Inputs to the layer.
| :code:`dt` (:code:`float`): Simulation time step.
'''
# Decay voltages and current.
self.v -= dt * self.decay * (self.v - self.rest)
self.i -= dt * self.i_decay * self.i

# Decrement refrac counters.
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)

# Refractoriness and voltage reset.
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)

# Integrate inputs.
self.i += inpts
self.v += self.i

super().step(inpts, dt)

def _reset(self):
'''
Resets relevant state variables.
'''
super()._reset()
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.i = torch.zeros(self.shape) # Synaptic input currents.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.


class AdaptiveLIFNodes(Nodes):
'''
Layer of leaky integrate-and-fire (LIF) neurons with adaptive thresholds.
Expand Down Expand Up @@ -296,7 +370,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.
self.theta_plus = theta_plus # Constant threshold increase on spike.
self.theta_decay = theta_decay # Rate of decay of adaptive thresholds.

Expand All @@ -322,11 +396,11 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh + self.theta) * (self.refrac_count == 0)
self.s = (self.v >= self.thresh + self.theta) & (self.refrac_count == 0)

# Refractoriness, voltage reset, and adaptive thresholds.
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
self.v = self.v.masked_fill(self.s, self.reset)
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)
self.theta += self.theta_plus * self.s.float()

# Integrate inputs.
Expand Down Expand Up @@ -372,7 +446,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.
self.theta_plus = theta_plus # Constant threshold increase on spike.
self.theta_decay = theta_decay # Rate of decay of adaptive thresholds.

Expand All @@ -397,11 +471,11 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh + self.theta) * (self.refrac_count == 0)
self.s = (self.v >= self.thresh + self.theta) & (self.refrac_count == 0)

# Refractoriness, voltage reset, and adaptive thresholds.
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
self.v = self.v.masked_fill(self.s, self.reset)
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)
self.theta += self.theta_plus * self.s.float()

# Choose only a single neuron to spike.
Expand Down Expand Up @@ -449,11 +523,11 @@ def __init__(self, n=None, shape=None, traces=False, excitatory=True, rest=-65.0
'''
super().__init__(n, shape, traces, trace_tc)

self.rest = rest # Rest voltage.
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.rest = rest # Rest voltage.
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.

if excitatory:
self.r = torch.rand(n)
Expand Down Expand Up @@ -485,11 +559,11 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) * (self.refrac_count == 0)
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)

# Refractoriness and voltage reset.
self.refrac_count = self.refrac_count.masked_fill(self.s, self.refrac)
self.v = self.v.masked_fill(self.s, self.reset)
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)

# Apply v and u updates.
self.v += dt * (0.04 * (self.v ** 2) + 5 * self.v + 140 - self.u + inpts)
Expand Down
5 changes: 4 additions & 1 deletion bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def __init__(self, source, target, nu=1e-2, nu_pre=1e-4, nu_post=1e-2, **kwargs)
self.w = kwargs.get('w', None)

if self.w is None:
self.w = self.wmin + torch.rand(*source.shape, *target.shape) * (self.wmin - self.wmin)
if self.wmin == -np.inf or self.wmax == np.inf:
self.w = torch.rand(*source.shape, *target.shape)
else:
self.w = self.wmin + torch.rand(*source.shape, *target.shape) * (self.wmin - self.wmin)
else:
if torch.max(self.w) > self.wmax or torch.min(self.w) < self.wmin:
warnings.warn('Weight matrix will be clamped between [%f, %f]; values may be biased to interval values.' % (self.wmin, self.wmax))
Expand Down
234 changes: 234 additions & 0 deletions examples/notebooks/LIFNodes vs. CurrentLIFNodes.ipynb

Large diffs are not rendered by default.

0 comments on commit 7895b44

Please sign in to comment.