diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index 6e2380bb..23f2b6d1 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -55,11 +55,14 @@ def plot_spikes(spikes, ims=None, axes=None, time=None, n_neurons={}, figsize=(8 Inputs: - | :code:`spikes` (:code:`dict(torch.Tensor)`): Contains spiking data for groups of neurons of interest. + | :code:`spikes` (:code:`dict(torch.Tensor)`): Contains + spiking data for groups of neurons of interest. | :code:`ims` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing the spike plots. | :code:`axes` (:code:`list(matplotlib.axes.Axes)`): Used for re-drawing the spike plots. - | :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons in the given time range. Default is entire simulation time. - | :code:`n_neurons` (:code:`dict(tuple(int))`): Plot spiking activity of neurons in the given range of neurons. Default is all neurons. + | :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons + in the given time range. Default is entire simulation time. + | :code:`n_neurons` (:code:`dict(tuple(int))`): Plot spiking activity + of neurons in the given range of neurons. Default is all neurons. | :code:`figsize` (:code:`tuple(int)`): Horizontal, vertical figure size in inches. Returns: @@ -81,8 +84,10 @@ def plot_spikes(spikes, ims=None, axes=None, time=None, n_neurons={}, figsize=(8 break if len(n_neurons.keys()) != 0: - assert len(n_neurons.keys()) <= n_subplots, 'n_neurons argument needs fewer entries than n_subplots' - assert all(key in spikes.keys() for key in n_neurons.keys()), 'n_neurons keys must be subset of spikes keys' + assert len(n_neurons.keys()) <= n_subplots, \ + 'n_neurons argument needs fewer entries than n_subplots' + assert all(key in spikes.keys() for key in n_neurons.keys()), \ + 'n_neurons keys must be subset of spikes keys' # Use all neurons if no argument provided. for key, val in spikes.items(): @@ -95,15 +100,19 @@ def plot_spikes(spikes, ims=None, axes=None, time=None, n_neurons={}, figsize=(8 if n_subplots == 1: for datum in spikes.items(): - ims.append(axes.imshow(spikes[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], time[0]:time[1]], cmap='binary')) - + ims.append(axes.imshow(spikes[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], + time[0]:time[1]], + cmap='binary')) + args = (datum[0], n_neurons[datum[0]][0], n_neurons[datum[0]][1], time[0], time[1]) plt.title('%s spikes for neurons (%d - %d) from t = %d to %d ' % args) plt.xlabel('Simulation time'); plt.ylabel('Neuron index') axes.set_aspect('auto') else: for i, datum in enumerate(spikes.items()): - ims.append(axes[i].imshow(datum[1][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], time[0]:time[1]], cmap='binary')) + ims.append(axes[i].imshow(datum[1][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], + time[0]:time[1]], + cmap='binary')) args = (datum[0], n_neurons[datum[0]][0], n_neurons[datum[0]][1], time[0], time[1]) axes[i].set_title('%s spikes for neurons (%d - %d) from t = %d to %d ' % args) @@ -195,7 +204,8 @@ def plot_conv2d_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)): for j in range(n_sqrt): if i * n_sqrt + j < weights.size(0): fltr = weights[i * n_sqrt + j].view(height, width) - reshaped[i * height : (i + 1) * height, (j % n_sqrt) * width : ((j % n_sqrt) + 1) * width] = fltr + reshaped[i * height : (i + 1) * height, + (j % n_sqrt) * width : ((j % n_sqrt) + 1) * width] = fltr if not im: fig, ax = plt.subplots(figsize=figsize) @@ -228,13 +238,17 @@ def plot_assignments(assignments, im=None, figsize=(5, 5), classes=None): Inputs: | :code:`assignments` (:code:`torch.Tensor`): Vector of neuron label assignments. - | :code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the assignments plot. - | :code:`figsize` (:code:`tuple(int)`): Horizontal, vertical figure size in inches. - | :code:`classes` (:code:`iterable`): Iterable of labels for colorbar ticks corresponding to data labels. + | :code:`im` (:code:`matplotlib.image.AxesImage`): + Used for re-drawing the assignments plot. + | :code:`figsize` (:code:`tuple(int)`): + Horizontal, vertical figure size in inches. + | :code:`classes` (:code:`iterable`): Iterable of + labels for colorbar ticks corresponding to data labels. Returns: - | (:code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the assigments plot. + | (:code:`im` (:code:`matplotlib.image.AxesImage`): + Used for re-drawing the assigments plot. ''' if not im: fig, ax = plt.subplots(figsize=figsize) @@ -242,12 +256,17 @@ def plot_assignments(assignments, im=None, figsize=(5, 5), classes=None): color = plt.get_cmap('RdBu', 11) im = ax.matshow(assignments, cmap=color, vmin=-1.5, vmax=9.5) - div = make_axes_locatable(ax); cax = div.append_axes("right", size="5%", pad=0.05) + div = make_axes_locatable(ax); cax = div.append_axes("right", + size="5%", + pad=0.05) if classes is None: plt.colorbar(im, cax=cax, ticks=np.arange(-1, 10)) else: - cbar = plt.colorbar(im, cax=cax, ticks=np.arange(-1, len(classes))) + cbar = plt.colorbar(im, + cax=cax, + ticks=np.arange(-1, len(classes))) + cbar.ax.set_yticklabels(classes) ax.set_xticks(()); ax.set_yticks(()) @@ -264,13 +283,17 @@ def plot_performance(performances, ax=None, figsize=(7, 4)): Inputs: - | :code:`performances` (:code:`dict(list(float))`): Lists of training accuracy estimates per voting scheme. - | :code:`ax` (:code:`matplotlib.axes.Axes`): Used for re-drawing the performance plot. - | :code:`figsize` (:code:`tuple(int)`): Horizontal, vertical figure size in inches. + | :code:`performances` (:code:`dict(list(float))`): + Lists of training accuracy estimates per voting scheme. + | :code:`ax` (:code:`matplotlib.axes.Axes`): + Used for re-drawing the performance plot. + | :code:`figsize` (:code:`tuple(int)`): + Horizontal, vertical figure size in inches. Returns: - | (:code:`ax` (:code:`matplotlib.axes.Axes`): Used for re-drawing the performance plot. + | (:code:`ax` (:code:`matplotlib.axes.Axes`): + Used for re-drawing the performance plot. ''' if not ax: _, ax = plt.subplots(figsize=figsize) @@ -295,18 +318,25 @@ def plot_general(monitor=None, ims=None, axes=None, labels=None, parameters=None Inputs: - | :code:`monitor` (:code:`monitors.Monitor`): Contains state variables to be plotted. - | :code:`ims` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing plots. - | :code:`axes` (:code:`list(matplotlib.axes.Axes)`): Used for re-drawing plots. - | :code:`labels` (:code:`dict(dict(string))`): Used to set axis labels and titles for plotted variables. - | :code:`parameters` (:code:`dict(dict(tuples(int)))`): Set time, number of neurons for plotted variables. - | :code:`figsize` (:code:`tuple(int)`): Horizontal, vertical figure size in inches. + | :code:`monitor` (:code:`monitors.Monitor`): + Contains state variables to be plotted. + | :code:`ims` (:code:`list(matplotlib.image.AxesImage)`): + Used for re-drawing plots. + | :code:`axes` (:code:`list(matplotlib.axes.Axes)`): + Used for re-drawing plots. + | :code:`labels` (:code:`dict(dict(string))`): + Used to set axis labels and titles for plotted variables. + | :code:`parameters` (:code:`dict(dict(tuples(int)))`): + Set time, number of neurons for plotted variables. + | :code:`figsize` (:code:`tuple(int)`): + Horizontal, vertical figure size in inches. Returns: - | (:code:`ims` (:code:`list(matplotlib.axes.Axes)): Used for re-drawing plots. - | (:code:`axes` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing plots. - + | (:code:`ims` (:code:`list(matplotlib.axes.Axes)): + Used for re-drawing plots. + | (:code:`axes` (:code:`list(matplotlib.image.AxesImage)`): + Used for re-drawing plots. ''' default = {'xlabel' : 'Simulation time', 'ylabel' : 'Index'} @@ -357,30 +387,40 @@ def plot_general(monitor=None, ims=None, axes=None, labels=None, parameters=None for var in monitor.state_vars: # For Weights if parameters[var]['cmap'] == 'hot_r' or parameters[var]['cmap'] == 'hot': - ims.append(axes.matshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], parameters[var]['time'][0]:parameters[var]['time'][1]])) + ims.append(axes.matshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], + parameters[var]['time'][0]:parameters[var]['time'][1]])) else: - ims.append(axes.imshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], parameters[var]['time'][0]:parameters[var]['time'][1]])) + ims.append(axes.imshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], + parameters[var]['time'][0]:parameters[var]['time'][1]])) + + plt.title(labels[var]['title']) + plt.xlabel(labels[var]['xlabel']) + plt.ylabel(labels[var]['ylabel']) - plt.title(labels[var]['title']); plt.xlabel(labels[var]['xlabel']); plt.ylabel(labels[var]['ylabel']) axes.set_aspect('auto') else: # Plot each monitor variable at a time for i, var in enumerate(monitor.state_vars): if parameters[var]['cmap'] == 'hot_r' or parameters[var]['cmap'] == 'hot': - ims.append(axes[i].matshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], parameters[var]['time'][0]:parameters[var]['time'][1]])) + ims.append(axes[i].matshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], + parameters[var]['time'][0]:parameters[var]['time'][1]])) else: - ims.append(axes[i].imshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], parameters[var]['time'][0]:parameters[var]['time'][1]])) + ims.append(axes[i].imshow(monitor.get(var)[parameters[var]['n_neurons'][0]:parameters[var]['n_neurons'][1], + parameters[var]['time'][0]:parameters[var]['time'][1]])) + + axes.set_title(labels[var]['title']) + axes.set_xlabel(labels[var]['xlabel']) + axes.set_ylabel(labels[var]['ylabel']) - axes.set_title(labels[var]['title']); axes.set_xlabel(labels[var]['xlabel']); axes.set_ylabel(labels[var]['ylabel']) - axes.set_aspect('auto') + axes.set_aspect('auto') # axes given else: assert(len(ims) == n_subplots) return ims, axes - - + + def plot_voltages(voltages, ims=None, axes=None, time=None, n_neurons={}, figsize=(8, 4.5)): ''' Plot voltages for any group(s) of neurons. @@ -418,8 +458,6 @@ def plot_voltages(voltages, ims=None, axes=None, time=None, n_neurons={}, figsiz assert(len(n_neurons.keys()) <= n_subplots) # Keys given must be same as the ones used in spikes dict assert(all(key in voltages.keys() for key in n_neurons.keys())) - # Checking to that given n_neurons per neuron layer is valid - assert(all(n_neurons[key][0] >= 0 and n_neurons[key][1] <= val.shape[0] for key, val in voltages.items() if key in n_neurons.keys()) == True) for key, val in voltages.items(): if key not in n_neurons.keys(): @@ -431,16 +469,26 @@ def plot_voltages(voltages, ims=None, axes=None, time=None, n_neurons={}, figsiz if n_subplots == 1: # Plotting only one image for datum in voltages.items(): - ims.append(axes.matshow(voltages[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], time[0]:time[1]])) - plt.title('%s voltages for neurons (%d - %d) from t = %d to %d '% (datum[0], n_neurons[datum[0]][0], n_neurons[datum[0]][1], time[0], time[1])) + ims.append(axes.matshow(voltages[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], + time[0]:time[1]])) + plt.title('%s voltages for neurons (%d - %d) from t = %d to %d ' % (datum[0], + n_neurons[datum[0]][0], + n_neurons[datum[0]][1], + time[0], + time[1])) plt.xlabel('Time (ms)'); plt.ylabel('Neuron index') axes.set_aspect('auto') else: # Plot each layer at a time for i, datum in enumerate(voltages.items()): - ims.append(axes[i].matshow(datum[1][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], time[0]:time[1]])) - axes[i].set_title('%s voltages for neurons (%d - %d) from t = %d to %d '% (datum[0], n_neurons[datum[0]][0], n_neurons[datum[0]][1], time[0], time[1])) + ims.append(axes[i].matshow(datum[1][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], + time[0]:time[1]])) + axes[i].set_title('%s voltages for neurons (%d - %d) from t = %d to %d ' % (datum[0], + n_neurons[datum[0]][0], + n_neurons[datum[0]][1], + time[0], + time[1])) for ax in axes: ax.set_aspect('auto') @@ -452,16 +500,26 @@ def plot_voltages(voltages, ims=None, axes=None, time=None, n_neurons={}, figsiz if n_subplots == 1: # Plotting only one image for datum in voltages.items(): axes.clear() - axes.matshow(voltages[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], time[0]:time[1]]) - axes.set_title('%s voltages for neurons (%d - %d) from t = %d to %d '% (datum[0], n_neurons[datum[0]][0], n_neurons[datum[0]][1], time[0], time[1])) + axes.matshow(voltages[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], + time[0]:time[1]]) + axes.set_title('%s voltages for neurons (%d - %d) from t = %d to %d ' % (datum[0], + n_neurons[datum[0]][0], + n_neurons[datum[0]][1], + time[0], + time[1])) axes.set_aspect('auto') else: # Plot each layer at a time for i, datum in enumerate(voltages.items()): axes[i].clear() - axes[i].matshow(voltages[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], time[0]:time[1]]) - axes[i].set_title('%s voltages for neurons (%d - %d) from t = %d to %d '% (datum[0], n_neurons[datum[0]][0], n_neurons[datum[0]][1], time[0], time[1])) + axes[i].matshow(voltages[datum[0]][n_neurons[datum[0]][0]:n_neurons[datum[0]][1], + time[0]:time[1]]) + axes[i].set_title('%s voltages for neurons (%d - %d) from t = %d to %d ' % (datum[0], + n_neurons[datum[0]][0], + n_neurons[datum[0]][1], + time[0], + time[1])) for ax in axes: ax.set_aspect('auto') diff --git a/bindsnet/analysis/visualization.py b/bindsnet/analysis/visualization.py index 74839b8b..6dcd709b 100644 --- a/bindsnet/analysis/visualization.py +++ b/bindsnet/analysis/visualization.py @@ -14,10 +14,9 @@ def plot_weights_movie(ws, sample_every=1): Inputs: - | :code:`ws` (:code:`numpy.array`): Numpy array of shape :code:`[N_examples, source, target, time]` - | :code:`sample_every` (:code:`int`): Sub-sample using this parameter. For example if :code:`time` is - too large (500), set this parameter to 20 to sample weights - every 20 iterations. + | :code:`ws` (:code:`numpy.array`): Numpy array + of shape :code:`[n_examples, source, target, time]` + | :code:`sample_every` (:code:`int`): Sub-sample using this parameter. """ weights = [] @@ -49,10 +48,13 @@ def plot_spike_trains_for_example(spikes, n_ex=None, top_k=None, indices=None): Inputs: - | :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): Spiking train data for a population of neurons for one example. - | :code:`n_ex` (:code:`int`): Allows user to pick which example to plot spikes for. Must be >= 0. + | :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`): + Spiking train data for a population of neurons for one example. + | :code:`n_ex` (:code:`int`): Allows user to pick + which example to plot spikes for. Must be >= 0. | :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example. - | :code:`indices` (:code:`list(int)`): Plot specific neurons' spiking activity instead of top_k. Meant to replace top_k. + | :code:`indices` (:code:`list(int)`): Plot specific neurons' + spiking activity instead of top_k. Meant to replace top_k. ''' assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0]) @@ -84,11 +86,16 @@ def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None): Inputs: - | :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): Tensor or array of shape :code:`[n_examples, n_neurons, time]`. - | :code:`n_ex` (:code:`int`): Allows user to pick which example to plot voltage for. - | :code:`n_neuron` (:code:`int`): Neuron index for which to plot voltages for. - | :code:`time` (:code:`tuple(int)`): Plot spiking activity of neurons between the given range of time. - | :code:`threshold` (:code:`float`): Neuron spiking threshold. Will be shown on the plot. + | :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`): + Tensor or array of shape :code:`[n_examples, n_neurons, time]`. + | :code:`n_ex` (:code:`int`): Allows user + to pick which example to plot voltage for. + | :code:`n_neuron` (:code:`int`): Neuron + index for which to plot voltages for. + | :code:`time` (:code:`tuple(int)`): Plot spiking + activity of neurons between the given range of time. + | :code:`threshold` (:code:`float`): Neuron + spiking threshold. Will be shown on the plot. ''' assert (n_ex >= 0 and n_neuron >= 0) @@ -105,7 +112,9 @@ def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None): plt.figure() plt.plot(voltage[n_ex, n_neuron, timer]) - plt.xlabel('Simulation Time'); plt.ylabel('Voltage'); plt.title('Membrane voltage of neuron %d for example %d'%(n_neuron, n_ex+1)) + plt.xlabel('Simulation Time') + plt.ylabel('Voltage') + plt.title('Membrane voltage of neuron %d for example %d' % (n_neuron, n_ex + 1)) locs, labels = plt.xticks() locs = range(int(locs[1]), int(locs[-1]), 10) plt.xticks(locs, time_ticks) diff --git a/bindsnet/datasets/__init__.py b/bindsnet/datasets/__init__.py index e6d62f42..a613af5b 100644 --- a/bindsnet/datasets/__init__.py +++ b/bindsnet/datasets/__init__.py @@ -29,8 +29,10 @@ def __init__(self, path='.', download=False): Inputs: - | :code:`path` (:code:`str`): Pathname of directory in which to store the dataset. - | :code:`download` (:code:`bool`): Whether or not to download the dataset (requires internet connection). + | :code:`path` (:code:`str`): Pathname of + directory in which to store the dataset. + | :code:`download` (:code:`bool`): Whether or not to + download the dataset (requires internet connection). ''' if not os.path.isdir(path): os.makedirs(path) @@ -85,12 +87,15 @@ class MNIST(Dataset): def __init__(self, path=os.path.join('data', 'MNIST'), download=False): ''' - Constructor for the :code:`MNIST` object. Makes the data directory if it doesn't already exist. + Constructor for the :code:`MNIST` object. Makes + the data directory if it doesn't already exist. Inputs: - | :code:`path` (:code:`str`): Pathname of directory in which to store the dataset. - | :code:`download` (:code:`bool`): Whether or not to download the dataset (requires internet connection). + | :code:`path` (:code:`str`): Pathname of + directory in which to store the dataset. + | :code:`download` (:code:`bool`): Whether or not + to download the dataset (requires internet connection). ''' super().__init__(path, download) @@ -113,7 +118,8 @@ def get_train(self): # Serialize image data on disk for next time. p.dump(images, open(os.path.join(self.path, MNIST.train_images_pickle), 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'MNIST(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: # Load image data from disk if it has already been processed. print('Loading training images from serialized object file.\n') @@ -129,7 +135,8 @@ def get_train(self): # Serialize label data on disk for next time. p.dump(labels, open(os.path.join(self.path, MNIST.train_labels_pickle), 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'MNIST(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: # Load label data from disk if it has already been processed. print('Loading training labels from serialized object file.\n') @@ -156,7 +163,8 @@ def get_test(self): # Serialize image data on disk for next time. p.dump(images, open(os.path.join(self.path, MNIST.test_images_pickle), 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'MNIST(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: # Load image data from disk if it has already been processed. print('Loading test images from serialized object file.\n') @@ -172,7 +180,8 @@ def get_test(self): # Serialize image data on disk for next time. p.dump(labels, open(os.path.join(self.path, MNIST.test_labels_pickle), 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'MNIST(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: # Load label data from disk if it has already been processed. print('Loading test labels from serialized object file.\n') @@ -187,7 +196,8 @@ def _download(self, url, filename): Inputs: | :code:`url` (:code:`str`): The URL of the data file to be downloaded. - | :code:`filename` (:code:`str`): The name of the file to save the downloaded data to. + | :code:`filename` (:code:`str`): The name + of the file to save the downloaded data to. ''' urlretrieve(url, os.path.join(self.path, filename + '.gz')) with gzip.open(os.path.join(self.path, filename + '.gz'), 'rb') as _in: @@ -200,7 +210,8 @@ def process_images(self, filename): Inputs: - | :code:`filename` (:code:`str`): Name of the file containing MNIST images to load. + | :code:`filename` (:code:`str`): Name of + the file containing MNIST images to load. Returns: @@ -267,7 +278,8 @@ def process_labels(self, filename): class SpokenMNIST(Dataset): ''' - Handles loading and saving of the Spoken MNIST audio dataset `(link) `_. + Handles loading and saving of the Spoken MNIST audio dataset + `(link) `_. ''' train_pickle = 'train.p' test_pickle = 'test.p' @@ -284,12 +296,15 @@ class SpokenMNIST(Dataset): def __init__(self, path=os.path.join('data', 'SpokenMNIST'), download=False): ''' - Constructor for the :code:`SpokenMNIST` object. Makes the data directory if it doesn't already exist. + Constructor for the :code:`SpokenMNIST` object. Makes + the data directory if it doesn't already exist. Inputs: - | :code:`path` (:code:`str`): Pathname of directory in which to store the dataset. - | :code:`download` (:code:`bool`): Whether or not to download the dataset (requires internet connection). + | :code:`path` (:code:`str`): Pathname of + directory in which to store the dataset. + | :code:`download` (:code:`bool`): Whether or not to + download the dataset (requires internet connection). ''' super().__init__(path, download) self.zip_path = os.path.join(path, 'repo.zip') @@ -308,7 +323,8 @@ def get_train(self, split=0.8): | :code:`labels` (:code:list(`torch.Tensor`)): The Spoken MNIST training labels. ''' split_index = int(split * SpokenMNIST.n_files) - + path = os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])) + if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]): # Download data if it isn't on disk. if self.download: @@ -319,20 +335,21 @@ def get_train(self, split=0.8): audio, labels = self.process_data(SpokenMNIST.files[:split_index]) # Serialize image data on disk for next time. - p.dump((audio, labels), open(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])), 'wb')) + p.dump((audio, labels), open(path, 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'SpokenMNIST(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: - if not os.path.isdir(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)]))): + if not os.path.isdir(path): # Process image and label data if pickled file doesn't exist. audio, labels = self.process_data(SpokenMNIST.files) # Serialize image data on disk for next time. - p.dump((audio, labels), open(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])), 'wb')) + p.dump((audio, labels), open(path, 'wb')) else: # Load image data from disk if it has already been processed. print('Loading training data from serialized object file.\n') - audio, labels = p.load(open(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])), 'rb')) + audio, labels = p.load(open(path, 'rb')) return audio, torch.Tensor(labels) @@ -346,6 +363,7 @@ def get_test(self, split=0.8): | :code:`labels` (:code:`torch.Tensor`): The Spoken MNIST training labels. ''' split_index = int(split * SpokenMNIST.n_files) + path = os.path.join(self.path, '_'.join([SpokenMNIST.test_pickle, str(split)])) if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]): # Download data if it isn't on disk. @@ -357,20 +375,21 @@ def get_test(self, split=0.8): audio, labels = self.process_data(SpokenMNIST.files[split_index:]) # Serialize image data on disk for next time. - p.dump((audio, labels), open(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])), 'wb')) + p.dump((audio, labels), open(path, 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'SpokenMNIST(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: - if not os.path.isdir(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)]))): + if not os.path.isdir(path): # Process image and label data if pickled file doesn't exist. audio, labels = self.process_data(SpokenMNIST.files) # Serialize image data on disk for next time. - p.dump((audio, labels), open(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])), 'wb')) + p.dump((audio, labels), open(path, 'wb')) else: # Load image data from disk if it has already been processed. print('Loading test data from serialized object file.\n') - audio, labels = p.load(open(os.path.join(self.path, '_'.join([SpokenMNIST.train_pickle, str(split)])), 'rb')) + audio, labels = p.load(open(path, 'rb')) return audio, torch.Tensor(labels) @@ -388,8 +407,9 @@ def _download(self): z.extractall(path=self.path) z.close() - for f in os.listdir(os.path.join(self.path, 'free-spoken-digit-dataset-master', 'recordings')): - shutil.move(os.path.join(self.path, 'free-spoken-digit-dataset-master', 'recordings', f), os.path.join(self.path)) + path = os.path.join(self.path, 'free-spoken-digit-dataset-master', 'recordings') + for f in os.listdir(path): + shutil.move(os.path.join(path, f), os.path.join(self.path)) cwd = os.getcwd() os.chdir(self.path) @@ -505,6 +525,7 @@ def get_train(self): | :code:`images` (:code:`torch.Tensor`): The CIFAR-10 training images. | :code:`labels` (:code:`torch.Tensor`): The CIFAR-10 training labels. ''' + path = os.path.join(self.path, CIFAR10.train_pickle) if not os.path.isdir(os.path.join(self.path, CIFAR10.data_directory)): # Download data if it isn't on disk. if self.download: @@ -513,20 +534,21 @@ def get_train(self): images, labels = self.process_data(CIFAR10.train_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR10.train_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'CIFAR10(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: - if not os.path.isdir(os.path.join(self.path, CIFAR10.train_pickle)): + if not os.path.isdir(path): # Process image and label data if pickled file doesn't exist. images, labels = self.process_data(CIFAR10.train_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR10.train_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: # Load image data from disk if it has already been processed. print('Loading training images from serialized object file.\n') - images, labels = p.load(open(os.path.join(self.path, CIFAR10.train_pickle), 'rb')) + images, labels = p.load(open(path, 'rb')) return torch.Tensor(images), torch.Tensor(labels) @@ -539,6 +561,7 @@ def get_test(self): | :code:`images` (:code:`torch.Tensor`): The CIFAR-10 test images. | :code:`labels` (:code:`torch.Tensor`): The CIFAR-10 test labels. ''' + path = os.path.join(self.path, CIFAR10.test_pickle) if not os.path.isdir(os.path.join(self.path, CIFAR10.data_directory)): # Download data if it isn't on disk. if self.download: @@ -547,20 +570,21 @@ def get_test(self): images, labels = self.process_data(CIFAR10.test_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR10.test_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'CIFAR10(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: - if not os.path.isdir(os.path.join(self.path, CIFAR10.test_pickle)): + if not os.path.isdir(path): # Process image and label data if pickled file doesn't exist. images, labels = self.process_data(CIFAR10.test_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR10.test_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: # Load image data from disk if it has already been processed. print('Loading test images from serialized object file.\n') - images, labels = p.load(open(os.path.join(self.path, CIFAR10.test_pickle), 'rb')) + images, labels = p.load(open(path, 'rb')) return torch.Tensor(images), torch.Tensor(labels) @@ -583,11 +607,13 @@ def process_data(self, filenames): Inputs: - | :code:`filename` (:code:`str`): Name of the file containing CIFAR-10 images and labels to load. + | :code:`filename` (:code:`str`): Name of the file + containing CIFAR-10 images and labels to load. Returns: - | (:code:`tuple(numpy.ndarray)`): Two :code:`numpy` arrays with image and label data, respectively. + | (:code:`tuple(numpy.ndarray)`): Two :code:`numpy` + arrays with image and label data, respectively. ''' d = {'data' : [], 'labels' : []} for filename in filenames: @@ -617,12 +643,15 @@ class CIFAR100(Dataset): def __init__(self, path=os.path.join('data', 'CIFAR100'), download=False): ''' - Constructor for the :code:`CIFAR100` object. Makes the data directory if it doesn't already exist. + Constructor for the :code:`CIFAR100` object. Makes + the data directory if it doesn't already exist. Inputs: - | :code:`path` (:code:`str`): Pathname of directory in which to store the dataset. - | :code:`download` (:code:`bool`): Whether or not to download the dataset (requires internet connection). + | :code:`path` (:code:`str`): Pathname of + directory in which to store the dataset. + | :code:`download` (:code:`bool`): Whether or not to + download the dataset (requires internet connection). ''' super().__init__(path, download) self.data_path = os.path.join(self.path, CIFAR100.data_directory) @@ -636,6 +665,7 @@ def get_train(self): | :code:`images` (:code:`torch.Tensor`): The CIFAR-100 training images. | :code:`labels` (:code:`torch.Tensor`): The CIFAR-100 training labels. ''' + path = os.path.join(self.path, CIFAR100.train_pickle) if not os.path.isdir(os.path.join(self.path, CIFAR100.data_directory)): # Download data if it isn't on disk. if self.download: @@ -644,20 +674,21 @@ def get_train(self): images, labels = self.process_data(CIFAR100.train_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR100.train_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'CIFAR100(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: - if not os.path.isdir(os.path.join(self.path, CIFAR100.train_pickle)): + if not os.path.isdir(path): # Process image and label data if pickled file doesn't exist. images, labels = self.process_data(CIFAR100.train_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR100.train_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: # Load image data from disk if it has already been processed. print('Loading training images from serialized object file.\n') - images, labels = p.load(open(os.path.join(self.path, CIFAR100.train_pickle), 'rb')) + images, labels = p.load(open(path, 'rb')) return torch.Tensor(images), torch.Tensor(labels) @@ -670,6 +701,7 @@ def get_test(self): | :code:`images` (:code:`torch.Tensor`): The CIFAR-100 test images. | :code:`labels` (:code:`torch.Tensor`): The CIFAR-100 test labels. ''' + path = os.path.join(self.path, CIFAR100.test_pickle) if not os.path.isdir(os.path.join(self.path, CIFAR100.data_directory)): # Download data if it isn't on disk. if self.download: @@ -678,20 +710,21 @@ def get_test(self): images, labels = self.process_data(CIFAR100.test_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR100.test_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: - raise FileNotFoundError('Dataset not found on disk; specify \'CIFAR100(..., download=True, ...)\' to allow downloads.') + msg = 'Dataset not found on disk; specify \'download=True\' to allow downloads.' + raise FileNotFoundError(msg) else: - if not os.path.isdir(os.path.join(self.path, CIFAR10.test_pickle)): + if not os.path.isdir(path): # Process image and label data if pickled file doesn't exist. images, labels = self.process_data(CIFAR100.test_files) # Serialize image data on disk for next time. - p.dump((images, labels), open(os.path.join(self.path, CIFAR100.test_pickle), 'wb')) + p.dump((images, labels), open(path, 'wb')) else: # Load image data from disk if it has already been processed. print('Loading test images from serialized object file.\n') - images, labels = p.load(open(os.path.join(self.path, CIFAR100.test_pickle), 'rb')) + images, labels = p.load(open(path, 'rb')) return torch.Tensor(images), torch.Tensor(labels) diff --git a/bindsnet/datasets/preprocess.py b/bindsnet/datasets/preprocess.py index 950846b0..6abf22e6 100644 --- a/bindsnet/datasets/preprocess.py +++ b/bindsnet/datasets/preprocess.py @@ -12,13 +12,12 @@ def gray_scale(im): | :code:`im` (:code:`numpy.array`): Grayscaled image ''' - im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) - return im + return cv2.cvtColor(im, cv2.COLOR_RGB2GRAY) def crop(im, x1, x2, y1, y2): return im[x1:x2, y1:y2, :] - + def binary_image(im): ''' @@ -32,7 +31,7 @@ def binary_image(im): | :code:`im` (:code:`numpy.array`): Black and white image. ''' - ret, im = cv2.threshold(im, 0, 1, cv2.THRESH_BINARY) + _, im = cv2.threshold(im, 0, 1, cv2.THRESH_BINARY) return im @@ -50,6 +49,4 @@ def subsample(im, x, y): | :code:`im` (:code:`numpy.array`): Rescaled image. ''' - im = cv2.resize(im, (x, y)) - return im - + return cv2.resize(im, (x, y)) diff --git a/bindsnet/encoding/__init__.py b/bindsnet/encoding/__init__.py index 7f327e2b..2d559c27 100644 --- a/bindsnet/encoding/__init__.py +++ b/bindsnet/encoding/__init__.py @@ -4,7 +4,9 @@ def bernoulli(datum, time=None, **kwargs): ''' - Generates Bernoulli-distributed spike trains based on input intensity. Inputs must be non-negative. Spikes correspond to successful Bernoulli trials, with success probability equal to (normalized in [0, 1]) input value. + Generates Bernoulli-distributed spike trains based on input intensity. + Inputs must be non-negative. Spikes correspond to successful Bernoulli + trials, with success probability equal to (normalized in [0, 1]) input value. Inputs: @@ -13,11 +15,13 @@ def bernoulli(datum, time=None, **kwargs): Keyword arguments: - | :code:`max_prob` (:code:`float`): Maximum probability of spike per Bernoulli trial. + | :code:`max_prob` (:code:`float`): Maximum + probability of spike per Bernoulli trial. Returns: - | (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` of Bernoulli-distributed spikes. + | (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` + of Bernoulli-distributed spikes. ''' # Setting kwargs. max_prob = kwargs.get('max_prob', 1.0) @@ -48,7 +52,8 @@ def bernoulli_loader(data, time=None, **kwargs): Inputs: - | :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): Tensor of shape :code:`[n_samples, n_1, ..., n_k]`. + | :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): + Tensor of shape :code:`[n_samples, n_1, ..., n_k]`. | :code:`time` (:code:`int`): Length of Bernoulli spike train per input variable. Keyword arguments: @@ -57,7 +62,8 @@ def bernoulli_loader(data, time=None, **kwargs): Yields: - | (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` of Bernoulli-distributed spikes. + | (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` + of Bernoulli-distributed spikes. ''' # Setting kwargs. @@ -68,7 +74,8 @@ def bernoulli_loader(data, time=None, **kwargs): def poisson(datum, time, **kwargs): ''' - Generates Poisson-distributed spike trains based on input intensity. Inputs must be non-negative. + Generates Poisson-distributed spike trains based + on input intensity. Inputs must be non-negative. Inputs: @@ -77,7 +84,8 @@ def poisson(datum, time, **kwargs): Returns: - | (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` of Poisson-distributed spikes. + | (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` + of Poisson-distributed spikes. ''' datum = np.copy(datum) shape, size = datum.shape, datum.size @@ -107,12 +115,14 @@ def poisson_loader(data, time, **kwargs): Inputs: - | :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): Tensor of shape :code:`[n_samples, n_1, ..., n_k]` + | :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): + Tensor of shape :code:`[n_samples, n_1, ..., n_k]` | :code:`time` (:code:`int`): Length of Poisson spike train per input variable. Yields: - | (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` of Poisson-distributed spikes. + | (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` + of Poisson-distributed spikes. ''' for i in range(len(data)): yield poisson(data[i], time) # Encode datum as Poisson spike trains. @@ -120,16 +130,21 @@ def poisson_loader(data, time, **kwargs): def rank_order(datum, time, **kwargs): ''' - Encodes data via a rank order coding-like representation. One spike per neuron, temporally ordered by decreasing intensity. Inputs must be non-negative. + Encodes data via a rank order coding-like representation. One + spike per neuron, temporally ordered by decreasing intensity. + Inputs must be non-negative. Inputs: - | :code:`data` (:code:`torch.Tensor`): Tensor of shape :code:`[n_samples, n_1, ..., n_k]` - | :code:`time` (:code:`int`): Length of Poisson spike train per input variable. + | :code:`data` (:code:`torch.Tensor`): Tensor + of shape :code:`[n_samples, n_1, ..., n_k]` + | :code:`time` (:code:`int`): Length of rank + order-encoded spike train per input variable. Returns: - | (:code:`torch.Tensor`): Tensor of shape :code:`[time, n_1, ..., n_k]` of Poisson-distributed spikes. + | (:code:`torch.Tensor`): Tensor of shape + :code:`[time, n_1, ..., n_k]` of rank order-encoded spikes. ''' datum = np.copy(datum) shape, size = datum.shape, datum.size @@ -155,16 +170,20 @@ def rank_order(datum, time, **kwargs): def rank_order_loader(data, time, **kwargs): ''' - Lazily invokes :code:`bindsnet.encoding.rank_order` to iteratively encode a sequence of data. + Lazily invokes :code:`bindsnet.encoding.rank_order` + to iteratively encode a sequence of data. Inputs: - | :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): Tensor of shape :code:`[n_samples, n_1, ..., n_k]` - | :code:`time` (:code:`int`): Length of Poisson spike train per input variable. + | :code:`data` (:code:`torch.Tensor` or iterable of :code:`torch.Tensor`s): + Tensor of shape :code:`[n_samples, n_1, ..., n_k]` + | :code:`time` (:code:`int`): Length of rank + order-encoded spike train per input variable. Yields: - | (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` of rank order-encoded spikes. + | (:code:`torch.Tensor`): Tensors of shape :code:`[time, n_1, ..., n_k]` + of rank order-encoded spikes. ''' for i in range(len(data)): yield rank_order(data[i], time) # Encode datum as rank order-encoded spike trains. \ No newline at end of file diff --git a/bindsnet/environment/__init__.py b/bindsnet/environment/__init__.py index 8321a44d..99a2ad03 100644 --- a/bindsnet/environment/__init__.py +++ b/bindsnet/environment/__init__.py @@ -33,7 +33,8 @@ def __init__(self, dataset, train=True, time=350, **kwargs): self.intensity = kwargs.get('intensity', 1) self.max_prob = kwargs.get('max_prob', 1) - assert self.max_prob > 0 and self.max_prob <= 1, 'Maximum spiking probability must be in (0, 1].' + assert self.max_prob > 0 and self.max_prob <= 1, \ + 'Maximum spiking probability must be in (0, 1].' if train: self.data, self.labels = self.dataset.get_train() @@ -50,14 +51,14 @@ def step(self, a=None): Inputs: - | :code:`a` (:code:`None`): There is no interaction of the network with the MNIST dataset. + | :code:`a` (:code:`None`): There is no interaction of the network the dataset. Returns: - | :code:`obs` (:code:`torch.Tensor`): Observation from the environment (spike train-encoded MNIST digit). + | :code:`obs` (:code:`torch.Tensor`): Observation from the environment. | :code:`reward` (:code:`float`): Fixed to :code:`0`. | :code:`done` (:code:`bool`): Fixed to :code:`False`. - | :code:`info` (:code:`dict`): Contains label of MNIST digit. + | :code:`info` (:code:`dict`): Contains label of data item. ''' try: # Attempt to fetch the next observation. @@ -140,7 +141,8 @@ def __init__(self, name, **kwargs): # Keyword arguments. self.max_prob = kwargs.get('max_prob', 1) - assert self.max_prob > 0 and self.max_prob <= 1, 'Maximum spiking probability must be in (0, 1].' + assert self.max_prob > 0 and self.max_prob <= 1, \ + 'Maximum spiking probability must be in (0, 1].' def step(self, a): ''' diff --git a/bindsnet/evaluation/__init__.py b/bindsnet/evaluation/__init__.py index 43658787..3aa7d191 100644 --- a/bindsnet/evaluation/__init__.py +++ b/bindsnet/evaluation/__init__.py @@ -7,16 +7,21 @@ def assign_labels(spikes, labels, n_labels, rates=None, alpha=1.0): Inputs: - | :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape :code:`(n_samples, time, n_neurons)` of a single layer's spiking activity. - | :code:`labels` (:code:`torch.Tensor`): Vector of shape :code:`(n_samples,)` with data labels corresponding to spiking activity. + | :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape + :code:`(n_samples, time, n_neurons)` of a single layer's spiking activity. + | :code:`labels` (:code:`torch.Tensor`): Vector of shape :code:`(n_samples,)` + with data labels corresponding to spiking activity. | :code:`n_labels` (:code:`int`): The number of target labels in the data. - | :code:`rates` (:code:`torch.Tensor`): If passed, these represent spike rates from a previous :code:`assign_labels()` call. + | :code:`rates` (:code:`torch.Tensor`): If passed, these represent spike + rates from a previous :code:`assign_labels()` call. | :code:`alpha` (:code:`float`): Rate of decay of label assignments. Returns: - | (:code:`torch.Tensor`): Vector of shape :code:`(n_neurons,)` of neuron label assignments. - | (:code:`torch.Tensor`): Vector of shape :code:`(n_neurons, n_labels)` of proportions of firing activity per neuron, per data label. + | (:code:`torch.Tensor`): Vector of shape + :code:`(n_neurons,)` of neuron label assignments. + | (:code:`torch.Tensor`): Vector of shape :code:`(n_neurons, n_labels)` + of proportions of firing activity per neuron, per data label. ''' n_neurons = spikes.size(2) @@ -53,13 +58,16 @@ def all_activity(spikes, assignments, n_labels): Inputs: - | :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape :code:`(n_samples, time, n_neurons)` of a layer's spiking activity. - | :code:`assignments` (:code:`torch.Tensor`): A vector of shape :code:`(n_neurons,)` of neuron label assignments. + | :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape + :code:`(n_samples, time, n_neurons)` of a layer's spiking activity. + | :code:`assignments` (:code:`torch.Tensor`): A vector of shape + :code:`(n_neurons,)` of neuron label assignments. | :code:`n_labels` (:code:`int`): The number of target labels in the data. Returns: - | (:code:`torch.Tensor`): Predictions tensor of shape :code:`(n_samples,)` resulting from the "all activity" classification scheme. + | (:code:`torch.Tensor`): Predictions tensor of shape :code:`(n_samples,)` + resulting from the "all activity" classification scheme. ''' n_samples = spikes.size(0) @@ -88,18 +96,23 @@ def all_activity(spikes, assignments, n_labels): def proportion_weighting(spikes, assignments, proportions, n_labels): ''' - Classify data with the label with highest average spiking activity over all neurons, weighted by class-wise proportion.. + Classify data with the label with highest average spiking + activity over all neurons, weighted by class-wise proportion. Inputs: - | :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape :code:`(n_samples, time, n_neurons)` of a single layer's spiking activity. - | :code:`assignments` (:code:`torch.Tensor`): A vector of shape :code:`(n_neurons,)` of neuron label assignments. - | :code:`proportions` (torch.Tensor): A matrix of shape :code:`(n_neurons, n_labels)` giving the per-class proportions of neuron spiking activity. + | :code:`spikes` (:code:`torch.Tensor`): Binary tensor of shape + :code:`(n_samples, time, n_neurons)` of a single layer's spiking activity. + | :code:`assignments` (:code:`torch.Tensor`): A vector of shape + :code:`(n_neurons,)` of neuron label assignments. + | :code:`proportions` (torch.Tensor): A matrix of shape :code:`(n_neurons, n_labels)` + giving the per-class proportions of neuron spiking activity. | :code:`n_labels` (:code:`int`): The number of target labels in the data. Returns: - | (:code:`torch.Tensor`): Predictions tensor of shape :code:`(n_samples,)` resulting from the "proportion weighting" classification scheme. + | (:code:`torch.Tensor`): Predictions tensor of shapez:code:`(n_samples,)` + resulting from the "proportion weighting" classification scheme. ''' n_samples = spikes.size(0) diff --git a/bindsnet/learning/__init__.py b/bindsnet/learning/__init__.py index 8c5add49..72b4520d 100644 --- a/bindsnet/learning/__init__.py +++ b/bindsnet/learning/__init__.py @@ -9,7 +9,8 @@ def post_pre(conn, **kwargs): Inputs: - | :code:`conn` (:code:`bindsnet.network.topology.Connection`): An instance of class :code:`Connection`. + | :code:`conn` (:code:`bindsnet.network.topology.AbstractConnection`): + An instance of class :code:`AbstractAbstractConnectionConnection`. ''' if not 'kernel_size' in conn.__dict__: x_source, x_target = conn.source.x.unsqueeze(-1), conn.target.x.unsqueeze(0) @@ -26,10 +27,22 @@ def post_pre(conn, **kwargs): out_channels, _, kernel_height, kernel_width = conn.w.size() padding, stride = conn.padding, conn.stride - x_source = im2col_indices(conn.source.x, kernel_height, kernel_width, padding=padding, stride=stride) - x_target = conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1) - s_source = im2col_indices(conn.source.s, kernel_height, kernel_width, padding=padding, stride=stride).float() - s_target = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float() + x_source = im2col_indices(conn.source.x, + kernel_height, + kernel_width, + padding=padding, + stride=stride) + + x_target = conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, + -1) + s_source = im2col_indices(conn.source.s, + kernel_height, + kernel_width, + padding=padding, + stride=stride).float() + + s_target = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, + -1).float() # Post-synaptic. post = s_target @ x_source.t() @@ -48,7 +61,8 @@ def hebbian(conn, **kwargs): Inputs: - | :code:`conn` (:code:`bindsnet.network.topology.Connection`): An instance of class :code:`Connection`. + | :code:`conn` (:code:`bindsnet.network.topology.AbstractConnection`): + An instance of class :code:`AbstractConnection`. ''' if not 'kernel_size' in conn.__dict__: # Post-synaptic. @@ -62,10 +76,22 @@ def hebbian(conn, **kwargs): out_channels, _, kernel_height, kernel_width = conn.w.size() padding, stride = conn.padding, conn.stride - x_source = im2col_indices(conn.source.x, kernel_height, kernel_width, padding=padding, stride=stride) - x_target = conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1) - s_source = im2col_indices(conn.source.s, kernel_height, kernel_width, padding=padding, stride=stride).float() - s_target = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float() + x_source = im2col_indices(conn.source.x, + kernel_height, + kernel_width, + padding=padding, + stride=stride) + + x_target = conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, + -1) + s_source = im2col_indices(conn.source.s, + kernel_height, + kernel_width, + padding=padding, + stride=stride).float() + + s_target = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, + -1).float() # Post-synaptic. post = (x_source @ s_target.t()).view(conn.w.size()) @@ -92,7 +118,8 @@ def m_stdp(conn, **kwargs): Inputs: - | :code:`conn` (:code:`bindsnet.network.topology.Connection`): An instance of class :code:`Connection`. + | :code:`conn` (:code:`bindsnet.network.topology.AbstractConnection`): + An instance of class :code:`AbstractConnection`. ''' # Parse keyword arguments. try: @@ -124,10 +151,22 @@ def m_stdp(conn, **kwargs): out_channels, _, kernel_height, kernel_width = conn.w.size() padding, stride = conn.padding, conn.stride - p_plus = a_plus * im2col_indices(conn.source.x, kernel_height, kernel_width, padding=padding, stride=stride) - p_minus = a_minus * conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1) - pre_fire = im2col_indices(conn.source.s, kernel_height, kernel_width, padding=padding, stride=stride).float() - post_fire = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float() + p_plus = a_plus * im2col_indices(conn.source.x, + kernel_height, + kernel_width, + padding=padding, + stride=stride) + + p_minus = a_minus * conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, + -1) + pre_fire = im2col_indices(conn.source.s, + kernel_height, + kernel_width, + padding=padding, + stride=stride).float() + + post_fire = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, + -1).float() # Post-synaptic. post = (p_plus @ post_fire.t()).view(conn.w.size()) @@ -146,7 +185,9 @@ def m_stdp(conn, **kwargs): conn.w += conn.nu * reward * eligibility # Bound weights. - conn.w = torch.clamp(conn.w, conn.wmin, conn.wmax) + conn.w = torch.clamp(conn.w, + conn.wmin, + conn.wmax) def m_stdp_et(conn, **kwargs): @@ -156,12 +197,13 @@ def m_stdp_et(conn, **kwargs): Inputs: - | :code:`conn` (:code:`bindsnet.network.topology.Connection`): An instance of class :code:`Connection`. + | :code:`conn` (:code:`bindsnet.network.topology.AbstractConnection`): + An instance of class :code:`AbstractConnection`. - | :code:`kwargs`: + | Keyword arguments: - :code:`a_plus` (:code:`int`): Learning rate (positive). - :code:`a_minus` (:code:`int`): Learning rate (negative). + | :code:`a_plus` (:code:`int`): Learning rate (positive). + | :code:`a_minus` (:code:`int`): Learning rate (negative). ''' if not 'kernel_size' in conn.__dict__: # Parse keyword arguments. @@ -182,7 +224,8 @@ def m_stdp_et(conn, **kwargs): post_fire = conn.target.s.float().unsqueeze(0) # Calculate value of eligibility trace. - conn.e_trace += -(conn.tc_e_trace * conn.e_trace) + (conn.p_plus * post_fire + pre_fire * conn.p_minus) + conn.e_trace += -(conn.tc_e_trace * conn.e_trace) + \ + conn.p_plus * post_fire + pre_fire * conn.p_minus # Compute weight update. conn.w += conn.nu * reward * conn.e_trace @@ -193,10 +236,22 @@ def m_stdp_et(conn, **kwargs): out_channels, _, kernel_height, kernel_width = conn.w.size() padding, stride = conn.padding, conn.stride - p_plus = a_plus * im2col_indices(conn.source.x, kernel_height, kernel_width, padding=padding, stride=stride) - p_minus = a_minus * conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, -1) - pre_fire = im2col_indices(conn.source.s, kernel_height, kernel_width, padding=padding, stride=stride).float() - post_fire = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, -1).float() + p_plus = a_plus * im2col_indices(conn.source.x, + kernel_height, + kernel_width, + padding=padding, + stride=stride) + + p_minus = a_minus * conn.target.x.permute(1, 2, 3, 0).reshape(out_channels, + -1) + pre_fire = im2col_indices(conn.source.s, + kernel_height, + kernel_width, + padding=padding, + stride=stride).float() + + post_fire = conn.target.s.permute(1, 2, 3, 0).reshape(out_channels, + -1).float() # Post-synaptic. post = (p_plus @ post_fire.t()).view(conn.w.size()) @@ -215,4 +270,6 @@ def m_stdp_et(conn, **kwargs): conn.w += conn.nu * reward * conn.e_trace # Bound weights. - conn.w = torch.clamp(conn.w, conn.wmin, conn.wmax) + conn.w = torch.clamp(conn.w, + conn.wmin, + conn.wmax) diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index f86c0479..037e331e 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -31,7 +31,8 @@ def __init__(self, obj, state_vars, time=None): # If simulation time is specified, pre-allocate recordings in memory for speed. else: - self.recording = {var : torch.zeros(*self.obj.__dict__[var].size(), self.time) for var in self.state_vars} + self.recording = {var : torch.zeros(*self.obj.__dict__[var].size(), + self.time) for var in self.state_vars} def get(self, var): ''' @@ -44,7 +45,7 @@ def get(self, var): Returns: | (:code:`torch.Tensor`): Tensor of shape :code:`[n_1, ..., n_k, time]`, - where :code:`[n_1, ..., n_k]` refers to the shape of the recorded state variable. + where :code:`[n_1, ..., n_k]` is the shape of the recorded state variable. ''' return self.recording[var] @@ -56,23 +57,14 @@ def record(self): if self.time is None: for var in self.state_vars: data = self.obj.__dict__[var].view(-1, 1).float() - self.recording[var] = torch.cat([self.recording[var], data], -1) + self.recording[var] = torch.cat([self.recording[var], + data], + -1) else: for var in self.state_vars: data = self.obj.__dict__[var].unsqueeze(-1) - - if len(data.size()) - 1 == 1: - self.recording[var][:, self.i % self.time] = data.squeeze() - - elif len(data.size()) - 1 == 2: - self.recording[var][:, :, self.i % self.time] = data.squeeze() - - elif len(data.size()) - 1 == 3: - self.recording[var][:, :, :, self.i % self.time] = data.squeeze() - - elif len(data.size()) - 1 == 4: - self.recording[var][:, :, :, :, self.i % self.time] = data.squeeze() - + self.recording[var][..., self.i % self.time] = data.squeeze() + self.i += 1 @@ -86,7 +78,8 @@ def _reset(self): # If simulation time is specified, pre-allocate recordings in memory for speed. else: - self.recording = {var : torch.zeros(*self.obj.__dict__[var].size(), self.time) for var in self.state_vars} + self.recording = {var : torch.zeros(*self.obj.__dict__[var].size(), + self.time) for var in self.state_vars} self.i = 0 @@ -142,11 +135,13 @@ def __init__(self, network, layers=None, connections=None, state_vars=['v', 's', for v in self.state_vars: for l in self.layers: if v in self.network.layers[l].__dict__: - self.recording[l][v] = torch.zeros(*self.network.layers[l].__dict__[v].size(), self.time) + self.recording[l][v] = torch.zeros(*self.network.layers[l].__dict__[v].size(), + self.time) for c in self.connections: if v in self.network.connections[c].__dict__: - self.recording[c][v] = torch.zeros(*self.network.connections[c].__dict__[v].size(), self.time) + self.recording[c][v] = torch.zeros(*self.network.connections[c].__dict__[v].size(), + self.time) def get(self): ''' @@ -163,28 +158,32 @@ def record(self): Appends the current value of the recorded state variables to the recording. ''' if self.time is None: - for var in self.state_vars: - for layer in self.layers: - if var in self.network.layers[layer].__dict__: - data = self.network.layers[layer].__dict__[var].unsqueeze(-1).float() - self.recording[layer][var] = torch.cat([self.recording[layer][var], data], -1) + for v in self.state_vars: + for l in self.layers: + if v in self.network.layers[l].__dict__: + data = self.network.layers[l].__dict__[v].unsqueeze(-1).float() + self.recording[l][v] = torch.cat([self.recording[l][v], + data], + -1) - for connection in self.connections: - if var in self.network.connections[connection].__dict__: - data = self.network.connections[connection].__dict__[var].unsqueeze(-1) - self.recording[connection][var] = torch.cat([self.recording[connection][var], data], -1) + for c in self.connections: + if v in self.network.connections[c].__dict__: + data = self.network.connections[c].__dict__[v].unsqueeze(-1) + self.recording[c][v] = torch.cat([self.recording[c][v], + data], + -1) else: - for var in self.state_vars: - for layer in self.layers: - if var in self.network.layers[layer].__dict__: - data = self.network.layers[layer].__dict__[var].float() - self.recording[layer][var][:, self.i % self.time] = data - - for connection in self.connections: - if var in self.network.connections[connection].__dict__: - data = self.network.connections[connection].__dict__[var] - self.recording[connection][var][:, :, self.i % self.time] = data + for v in self.state_vars: + for l in self.layers: + if v in self.network.layers[l].__dict__: + data = self.network.layers[l].__dict__[v].float() + self.recording[l][v][..., self.i % self.time] = data + + for c in self.connections: + if v in self.network.connections[c].__dict__: + data = self.network.connections[c].__dict__[v] + self.recording[c][v][..., self.i % self.time] = data self.i += 1 @@ -203,11 +202,13 @@ def save(self, path, fmt='npz'): if fmt == 'npz': # Build a list of arrays to write to disk. arrays = {} - for obj in self.recording: - if type(obj) == tuple: - arrays.update({'_'.join(['-'.join(obj), var]) : self.recording[obj][var] for var in self.recording[obj]}) - elif type(obj) == str: - arrays.update({'_'.join([obj, var]) : self.recording[obj][var] for var in self.recording[obj]}) + for o in self.recording: + if type(o) == tuple: + k = '_'.join(['-'.join(o), v]) + arrays.update({k : self.recording[o][v] for v in self.recording[o]}) + elif type(o) == str: + k = '_'.join([o, v]) + arrays.update({k : self.recording[o][v] for v in self.recording[o]}) np.savez_compressed(path, **arrays) @@ -239,8 +240,10 @@ def _reset(self): for v in self.state_vars: for l in self.layers: if v in self.network.layers[l].__dict__: - self.recording[l][v] = torch.zeros(self.network.layers[l].n, self.time) + self.recording[l][v] = torch.zeros(self.network.layers[l].n, + self.time) for c in self.connections: if v in self.network.connections[c].__dict__: - self.recording[c][v] = torch.zeros(*self.network.connections[c].w.size(), self.time) + self.recording[c][v] = torch.zeros(*self.network.connections[c].w.size(), + self.time) diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 50c609b6..443b399e 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -12,7 +12,8 @@ class Nodes(ABC): def __init__(self, n=None, shape=None, traces=False, trace_tc=5e-2): super().__init__() - assert not n is None or not shape is None, 'Must provide either no. of neurons or shape of nodes' + assert not n is None or not shape is None, \ + 'Must provide either no. of neurons or shape of nodes' if n is None: self.n = reduce(mul, shape) # No. of neurons product of shape. @@ -24,7 +25,8 @@ def __init__(self, n=None, shape=None, traces=False, trace_tc=5e-2): else: self.shape = shape # Shape is passed in as an argument. - assert self.n == reduce(mul, self.shape), 'No. of neurons and shape do not match' + assert self.n == reduce(mul, self.shape), \ + 'No. of neurons and shape do not match' self.traces = traces # Whether to record synpatic traces. self.s = torch.zeros(self.shape).byte() # Spike occurences. @@ -114,8 +116,8 @@ def __init__(self, n=None, shape=None, traces=False, thresh=1.0, trace_tc=5e-2): ''' super().__init__(n, shape, traces, trace_tc) - self.thresh = thresh # Spike threshold voltage. - self.v = torch.zeros(self.shape) # Neuron voltages. + self.thresh = thresh # Spike threshold voltage. + self.v = torch.zeros(self.shape) # Neuron voltages. def step(self, inpts, dt): ''' diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 028a69c6..232f2dc9 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -126,7 +126,7 @@ def __init__(self, source, target, nu=1e-2, nu_pre=1e-4, nu_post=1e-2, **kwargs) self.w = self.wmin + torch.rand(*source.shape, *target.shape) * (self.wmin - self.wmin) else: if torch.max(self.w) > self.wmax or torch.min(self.w) < self.wmin: - warnings.warn('Weight matrix will be clamped between [%f, %f]; values may be biased to interval values.' % (self.wmin, self.wmax)) + warnings.warn(f'Weight matrix will be clamped between [{self.wmin}, {self.wmax}]') self.w = torch.clamp(self.w, self.wmin, self.wmax) def compute(self, s): @@ -206,22 +206,30 @@ def __init__(self, source, target, kernel_size, stride=1, padding=0, self.padding = _pair(padding) self.dilation = _pair(dilation) - assert source.shape[0] == target.shape[0], 'Minibatch size not equal across source and target populations' + assert source.shape[0] == target.shape[0], 'Minibatch size not equal across source and target' minibatch = source.shape[0] self.in_channels, input_height, input_width = source.shape[1], source.shape[2], source.shape[3] self.out_channels, output_height, output_width = target.shape[1], target.shape[2], target.shape[3] - error_message = 'Target dimensionality must be (minibatch, out_channels, \ + error = 'Target dimensionality must be (minibatch, out_channels, \ (input_height - filter_height + 2 * padding_height) / stride_height + 1, \ (input_width - filter_width + 2 * padding_width) / stride_width + 1' assert tuple(target.shape) == (minibatch, self.out_channels, - (input_height - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0] + 1, - (input_width - self.kernel_size[1] + 2 * self.padding[1]) / self.stride[1] + 1), error_message + (input_height - self.kernel_size[0] + \ + 2 * self.padding[0]) / self.stride[0] + 1, + (input_width - self.kernel_size[1] + \ + 2 * self.padding[1]) / self.stride[1] + 1), error - self.w = kwargs.get('w', torch.rand(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])) - self.w = torch.clamp(self.w, self.wmin, self.wmax) + self.w = kwargs.get('w', torch.rand(self.out_channels, + self.in_channels, + self.kernel_size[0], + self.kernel_size[1])) + + self.w = torch.clamp(self.w, + self.wmin, + self.wmax) def compute(self, s): ''' @@ -231,7 +239,11 @@ def compute(self, s): | :code:`s` (:code:`torch.Tensor`): Incoming spikes. ''' - return F.conv2d(s.float(), self.w, stride=self.stride, padding=self.padding, dilation=self.dilation) + return F.conv2d(s.float(), + self.w, + stride=self.stride, + padding=self.padding, + dilation=self.dilation) def update(self, **kwargs): ''' @@ -245,7 +257,9 @@ def normalize(self): ''' if self.norm is not None: shape = self.w.size() - self.w = self.w.view(self.w.size(0), self.w.size(2) * self.w.size(3)) + self.w = self.w.view(self.w.size(0), + self.w.size(2) * self.w.size(3)) + for fltr in range(self.w.size(0)): self.w[fltr] *= self.norm / self.w[fltr].sum(0) diff --git a/bindsnet/pipeline/__init__.py b/bindsnet/pipeline/__init__.py index 97ab584d..6d6c9041 100644 --- a/bindsnet/pipeline/__init__.py +++ b/bindsnet/pipeline/__init__.py @@ -60,26 +60,27 @@ def __init__(self, network, environment, encoding=bernoulli, action_function=Non self.history_length = kwargs.get('history_length', None) self.render_interval = kwargs.get('render_interval', None) - # Make sure inputs are valid. - assert type(self.time == int and self.time >= 1), "Invalid input %d; Time \ - presented to the network cannot be \ - 0 or negative and must be an integer"%(self.time) - assert self.delta >= 1, "Invalid input %d; Delta cannot be 0 or negative"%(self.delta) - if self.output is not None: - assert self.output in self.network.layers, "Layer '%s' not found inside network"%(self.output) - if self.history_length is not None and self.delta is not None: - self.history = {i : torch.Tensor() for i in range(1, self.history_length * self.delta + 1, self.delta)} + self.history = {i : torch.Tensor() for i in range(1, + self.history_length * self.delta + 1, + self.delta)} else: self.history = {} if self.plot_interval is not None: - for layer in self.network.layers: - self.network.add_monitor(Monitor(self.network.layers[layer], 's', self.plot_interval * self.time), name='%s_spikes' % layer) - if 'v' in self.network.layers[layer].__dict__: - self.network.add_monitor(Monitor(self.network.layers[layer], 'v', self.plot_interval * self.time), name='%s_voltages' % layer) + for l in self.network.layers: + self.network.add_monitor(Monitor(self.network.layers[l], + 's', + self.plot_interval * self.time), + name=f'{l}_spikes') + + if 'v' in self.network.layers[l].__dict__: + self.network.add_monitor(Monitor(self.network.layers[l], + 'v', + self.plot_interval * self.time), + name=f'{l}_voltages') - self.spike_record = {layer : torch.ByteTensor() for layer in self.network.layers} + self.spike_record = {l : torch.ByteTensor() for l in self.network.layers} self.set_spike_data() self.plot_data() @@ -93,16 +94,16 @@ def set_spike_data(self): ''' Get the spike data from all layers in the pipeline's network. ''' - self.spike_record = {layer : self.network.monitors['%s_spikes' % layer].get('s') for layer in self.network.layers} + self.spike_record = {l : self.network.monitors[f'{l}_spikes'].get('s') for l in self.network.layers} def set_voltage_data(self): ''' Get the voltage data from all applicable layers in the pipeline's network. ''' self.voltage_record = {} - for layer in self.network.layers: - if 'v' in self.network.layers[layer].__dict__: - self.voltage_record[layer] = self.network.monitors['%s_voltages' % layer].get('v') + for l in self.network.layers: + if 'v' in self.network.layers[l].__dict__: + self.voltage_record[l] = self.network.monitors[f'{l}_voltages'].get('v') def step(self, **kwargs): ''' @@ -111,11 +112,11 @@ def step(self, **kwargs): clamp = kwargs.get('clamp', {}) if self.print_interval is not None and self.iteration % self.print_interval == 0: - print('Iteration: %d (Time: %.4f)' % (self.iteration, time.time() - self.clock)) + print(f'Iteration: {self.iteration} (Time: {time.time() - self.clock:.4f})') self.clock = time.time() if self.save_interval is not None and self.iteration % self.save_interval == 0: - print('Saving network to %s' % self.save_dir) + print(f'Saving network to {self.save_dir}') self.network.save(self.save_dir) # Render game. @@ -234,4 +235,4 @@ def _reset(self): self.env.reset() self.network._reset() self.iteration = 0 - self.history = self.history = {i: torch.Tensor() for i in self.history} + self.history = {i: torch.Tensor() for i in self.history} diff --git a/bindsnet/pipeline/action.py b/bindsnet/pipeline/action.py index ee2bfef1..9d81b720 100644 --- a/bindsnet/pipeline/action.py +++ b/bindsnet/pipeline/action.py @@ -8,18 +8,20 @@ def select_multinomial(pipeline, **kwargs): Inputs: - | :code:`pipeline` (:code:`bindsnet.pipeline.Pipeline`): Pipeline with environment that accepts feedback in the form of actions. + | :code:`pipeline` (:code:`bindsnet.pipeline.Pipeline`): Pipeline + with environment that has an integer action space. Returns: - | (:code:`int`): Number indicating the desired action from the action space. + | (:code:`int`): Integer indicating an action from the action space. ''' try: output = kwargs['output'] except KeyError: - raise KeyError('select_action requires an output layer of size a multiple of the action space.') + raise KeyError('select_multinomial() requires an "output" layer argument.') - assert pipeline.network.layers[output].n % pipeline.env.action_space.n == 0, 'Output layer size not equal to size of action space.' + assert pipeline.network.layers[output].n % pipeline.env.action_space.n == 0, \ + 'Output layer size not equal to size of action space.' pop_size = int(pipeline.network.layers[output].n / pipeline.env.action_space.n) @@ -38,11 +40,12 @@ def select_multinomial(pipeline, **kwargs): def select_softmax(pipeline, **kwargs): ''' - Selects an action using softmax probability function based on spiking activity from a network layer. + Selects an action using softmax function based on spiking from a network layer. Inputs: - | :code:`pipeline` (:code:`bindsnet.pipeline.Pipeline`): Pipeline with environment that accepts feedback in the form of actions. + | :code:`pipeline` (:code:`bindsnet.pipeline.Pipeline`): Pipeline + with environment that accepts feedback in the form of actions. Returns: @@ -51,9 +54,10 @@ def select_softmax(pipeline, **kwargs): try: output = kwargs['output'] except KeyError: - raise KeyError('select_action requires an output layer of size equal to the action space.') + raise KeyError('select_softmax() requires an "output" layer argument.') - assert pipeline.network.layers[output].n == pipeline.env.action_space.n, 'Output layer size not equal to size of action space.' + assert pipeline.network.layers[output].n == pipeline.env.action_space.n, \ + 'Output layer size not equal to size of action space.' # Sum of previous iterations' spikes (Not yet implemented) spikes = pipeline.network.layers[output].s @@ -73,7 +77,8 @@ def select_random(pipeline, **kwargs): Inputs: - | :code:`pipeline` (:code:`bindsnet.pipeline.Pipeline`): Pipeline with environment that accepts feedback in the form of actions. + | :code:`pipeline` (:code:`bindsnet.pipeline.Pipeline`): Pipeline + with environment that accepts feedback in the form of actions. Returns: