Skip to content

Commit

Permalink
Merge pull request #19 from project-rig/fix-#18
Browse files Browse the repository at this point in the history
Apply pre-slice before connection functions
  • Loading branch information
tcstewar committed Jun 16, 2015
2 parents 01179c2 + 8a0ebf7 commit dfb358a
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 7 deletions.
3 changes: 1 addition & 2 deletions nengo_spinnaker/builder/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ def generic_sink_getter(model, conn):

@Model.connection_parameter_builders.register(nengo.base.NengoObject)
def build_generic_connection_params(model, conn):
transform = full_transform(conn)
return BuiltConnection(
decoders=None,
transform=transform,
transform=full_transform(conn, slice_pre=False),
eval_points=None,
solver_info=None
)
4 changes: 2 additions & 2 deletions nengo_spinnaker/node_io/ethernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def set_node_output(self, node, value):
# Build an SDP packet to transmit for each outgoing connection for the
# node
for connection, (x, y, p) in self._node_outgoing[node]:
# Perform connection function and transform
c_value = value[:]
# Apply the pre-slice, the connection function and the transform.
c_value = value[connection.pre_slice]
if connection.function is not None:
c_value = connection.function(c_value)
c_value = np.dot(connection.transform, c_value)
Expand Down
20 changes: 18 additions & 2 deletions nengo_spinnaker/operators/value_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import struct

from nengo.processes import Process
from nengo.utils import numpy as npext

from nengo_spinnaker.builder.builder import OutputPort, netlistspec
from nengo_spinnaker.netlist import VertexSlice
Expand Down Expand Up @@ -50,9 +51,9 @@ def make_vertices(self, model, n_steps):

# Add the keys for this connection
conn = conns[0]
so = conns[0].size_out
keys.extend(list(
get_derived_keyspaces(sig.keyspace, slice(0, so))
get_derived_keyspaces(sig.keyspace, conn.post_slice,
max_v=conn.post_obj.size_in)
))
self.conns.append(conn)
size_out = len(keys)
Expand Down Expand Up @@ -140,16 +141,31 @@ def before_simulation(self, netlist, simulator, n_steps):
else:
values = np.array([self.function for t in ts])

# Ensure that the values can be sliced, regardless of how they were
# generated.
values = npext.array(values, min_dims=2)

# Compute the output for each connection
outputs = []
for conn in self.conns:
output = []

# For each f(t) for the next set of simulations we calculate the
# output at the end of the connection. To do this we first apply
# the pre-slice, then the function and then the post-slice.
for v in values:
# Apply the pre-slice
v = v[conn.pre_slice]

# Apply the function on the connection, if there is one.
if conn.function is not None:
v = conn.function(v)

output.append(np.dot(conn.transform, v.T))
outputs.append(np.array(output).reshape(n_steps, conn.size_out))

# Combine all of the output values to form a large matrix which we can
# dump into memory.
output_matrix = np.hstack(outputs)

new_output_region = regions.MatrixRegion(
Expand Down
58 changes: 58 additions & 0 deletions regression-tests/test_nodes_sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""More complex function of time Node example.
"""
import nengo
import nengo_spinnaker
import numpy as np
import pytest


@pytest.mark.parametrize("f_of_t", [True, False])
def test_nodes_sliced(f_of_t):
# Create a model with a single function of time node which returns a 4D
# vector, apply preslicing on some connections from it and ensure that this
# slicing plays nicely with the functions attached to the connections.
def out_fun_1(val):
assert val.size == 2
return val * 2

with nengo.Network() as model:
# Create the input node and an ensemble
in_node = nengo.Node(lambda t: [0.1, 1.0, 0.2, -1.0], size_out=4)
in_node_2 = nengo.Node(0.25)

ens = nengo.Ensemble(400, 4)
ens2 = nengo.Ensemble(200, 2)

# Create the connections
nengo.Connection(in_node[::2], ens[[1, 3]], transform=.5,
function=out_fun_1)
nengo.Connection(in_node_2[[0, 0]], ens2)

# Probe the ensemble to ensure that the values are correct
p = nengo.Probe(ens, synapse=0.05)
p2 = nengo.Probe(ens2, synapse=0.05)

# Mark the input as being a function of time if desired
if f_of_t:
nengo_spinnaker.add_spinnaker_params(model.config)
model.config[in_node].function_of_time = True

# Run the simulator for 1.0 s and check that the last probed values are in
# range
sim = nengo_spinnaker.Simulator(model)
with sim:
sim.run(1.0)

# Check the final values
assert -0.05 < sim.data[p][-1, 0] < 0.05
assert 0.05 < sim.data[p][-1, 1] < 0.15
assert -0.05 < sim.data[p][-1, 2] < 0.05
assert 0.15 < sim.data[p][-1, 3] < 0.25

assert 0.20 < sim.data[p2][-1, 0] < 0.30
assert 0.20 < sim.data[p2][-1, 1] < 0.30


if __name__ == "__main__":
test_nodes_sliced(True)
test_nodes_sliced(False)
2 changes: 1 addition & 1 deletion tests/builder/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ def test_build_standard_connection_params():
# Build the connection parameters
params = build_generic_connection_params(None, a_b)
assert params.decoders is None
assert np.all(params.transform == [[1.0, 0.0]])
assert params.transform == 1.0
assert params.eval_points is None
assert params.solver_info is None

0 comments on commit dfb358a

Please sign in to comment.