From 90365bd95504ba27436780ebdb4070836184e9ab Mon Sep 17 00:00:00 2001 From: ixaxaar Date: Thu, 26 Oct 2017 20:59:05 +0530 Subject: [PATCH] Initial commit, pushed into pypi --- .gitignore | 2 + README.md | 69 +++++++++++++ dnc/__init__.py | 1 + dnc/copy_task.py | 166 ++++++++++++++++++++++++++++++ dnc/dnc.py | 255 ++++++++++++++++++++++++++++++++++++++++++++++ dnc/memory.py | 256 +++++++++++++++++++++++++++++++++++++++++++++++ dnc/util.py | 154 ++++++++++++++++++++++++++++ setup.cfg | 2 + setup.py | 67 +++++++++++++ 9 files changed, 972 insertions(+) create mode 100644 README.md create mode 100644 dnc/__init__.py create mode 100644 dnc/copy_task.py create mode 100644 dnc/dnc.py create mode 100644 dnc/memory.py create mode 100644 dnc/util.py create mode 100644 setup.cfg create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index 17b2dab..14b2842 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ __pycache__/ *.lang *.log .cache/ +dist/ +dnc.egg-info/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..afea442 --- /dev/null +++ b/README.md @@ -0,0 +1,69 @@ +# Differentiable Neural Computer, for Pytorch + +This is an implementation of [Differentiable Neural Computers](http://people.idsia.ch/~rupesh/rnnsymposium2016/slides/graves.pdf), described in the paper [Hybrid computing using a neural network with dynamic external memory, Graves et al.](https://www.nature.com/articles/nature20101) + +## Install + +```bash +pip install dnc +``` + +## Usage + +**Parameters**: + +| Argument | Default | Description | +| --- | --- | --- | +| input_size | None | Size of the input vectors | +| hidden_size | None | Size of hidden units | +| rnn_type | 'lstm' | Type of recurrent cells used in the controller | +| num_layers | 1 | Number of layers of recurrent units in the controller | +| bias | True | Bias | +| batch_first | True | Whether data is fed batch first | +| dropout | 0 | Dropout between layers in the controller (Not yet implemented) | +| bidirectional | 
False | If the controller is bidirectional (Not yet implemented) | +| nr_cells | 5 | Number of memory cells | +| read_heads | 2 | Number of read heads | +| cell_size | 10 | Size of each memory cell | +| nonlinearity | 'tanh' | If using 'rnn' as `rnn_type`, non-linearity of the RNNs | +| gpu_id | -1 | ID of the GPU, -1 for CPU | +| independent_linears | False | Whether to use independent linear units to derive interface vector | +| share_memory | True | Whether to share memory between controller layers | + + +Example usage: + +```python +from dnc import DNC + +rnn = DNC( + input_size=64, + hidden_size=128, + rnn_type='lstm', + num_layers=4, + nr_cells=100, + cell_size=32, + read_heads=4, + batch_first=True, + gpu_id=0 +) + +(controller_hidden, memory, read_vectors) = (None, None, None) + +output, (controller_hidden, memory, read_vectors) = \ + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors)) +``` + +## Example copy task + +The copy task, as described in the original paper, is included in the repo. + +``` +python ./copy_task.py -cuda 0 +``` + +## General noteworthy stuff + +1. DNCs converge with Adam and RMSProp learning rules; SGD generally causes them to diverge. +2. Using a large batch size (> 100, recommended 1000) prevents gradients from becoming `NaN`. 
+ diff --git a/dnc/__init__.py b/dnc/__init__.py new file mode 100644 index 0000000..5f7ce86 --- /dev/null +++ b/dnc/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 \ No newline at end of file diff --git a/dnc/copy_task.py b/dnc/copy_task.py new file mode 100644 index 0000000..ac3ccdd --- /dev/null +++ b/dnc/copy_task.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +import warnings +warnings.filterwarnings('ignore') + +import numpy as np +import getopt +import sys +import os +import math +import time +import argparse + +sys.path.insert(0, os.path.join('..', '..')) + +import torch as T +from torch.autograd import Variable as var +import torch.nn.functional as F +import torch.optim as optim + +from torch.nn.utils import clip_grad_norm + +from dnc import DNC + +parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer') +parser.add_argument('-input_size', type=int, default= 6, help='dimension of input feature') +parser.add_argument('-nhid', type=int, default=64, help='humber of hidden units of the inner nn') + +parser.add_argument('-nlayer', type=int, default=2, help='number of layers') +parser.add_argument('-lr', type=float, default=1e-2, help='initial learning rate') +parser.add_argument('-clip', type=float, default=0.5, help='gradient clipping') + +parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size') +parser.add_argument('-mem_size', type=int, default=16, help='memory dimension') +parser.add_argument('-mem_slot', type=int, default=15, help='number of memory slots') +parser.add_argument('-read_heads', type=int, default=1, help='number of read heads') + +parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length') +parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU') +parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval') + +parser.add_argument('-iterations', type=int, default=100000, 
metavar='N', help='total number of iteration') +parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency') +parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='check point frequency') + +args = parser.parse_args() +print(args) + +if args.cuda != -1: + print('Using CUDA.') + T.manual_seed(1111) +else: + print('Using CPU.') + + +def llprint(message): + sys.stdout.write(message) + sys.stdout.flush() + + +def generate_data(batch_size, length, size, cuda=-1): + + input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) + target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) + + sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1)) + + input_data[:, :length, :size - 1] = sequence + input_data[:, length, -1] = 1 # the end symbol + target_output[:, length + 1:, :size - 1] = sequence + + input_data = T.from_numpy(input_data) + target_output = T.from_numpy(target_output) + if cuda != -1: + input_data = input_data.cuda() + target_output = target_output.cuda() + + return var(input_data), var(target_output) + + +def criterion(predictions, targets): + return T.mean( + -1 * F.logsigmoid(predictions) * (targets) - T.log(1 - F.sigmoid(predictions) + 1e-9) * (1 - targets) + ) + +if __name__ == '__main__': + + dirname = os.path.dirname(__file__) + ckpts_dir = os.path.join(dirname, 'checkpoints') + if not os.path.isdir(ckpts_dir): + os.mkdir(ckpts_dir) + + batch_size = args.batch_size + sequence_max_length = args.sequence_max_length + iterations = args.iterations + summarize_freq = args.summarize_freq + check_freq = args.check_freq + + # input_size = output_size = args.input_size + mem_slot = args.mem_slot + mem_size = args.mem_size + read_heads = args.read_heads + + + # options, _ = getopt.getopt(sys.argv[1:], '', ['iterations=']) + + # for opt in options: + # if opt[0] == '-iterations': + # iterations = int(opt[1]) + + rnn = DNC( + 
input_size=args.input_size, + hidden_size=args.nhid, + rnn_type='lstm', + num_layers=args.nlayer, + nr_cells=mem_slot, + cell_size=mem_size, + read_heads=read_heads, + gpu_id=args.cuda + ) + + if args.cuda != -1: + rnn = rnn.cuda(args.cuda) + + last_save_losses = [] + + optimizer = optim.Adam(rnn.parameters(), lr=args.lr) + + for epoch in range(iterations + 1): + llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations)) + optimizer.zero_grad() + + random_length = np.random.randint(1, sequence_max_length + 1) + + input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda) + # input_data = input_data.transpose(0, 1).contiguous() + target_output = target_output.transpose(0, 1).contiguous() + + output, _ = rnn(input_data, None) + output = output.transpose(0, 1) + + loss = criterion((output), target_output) + # if np.isnan(loss.data.cpu().numpy()): + # llprint('\nGot nan loss, contine to jump the backward \n') + + # apply_dict(locals()) + loss.backward() + + optimizer.step() + loss_value = loss.data[0] + + summerize = (epoch % summarize_freq == 0) + take_checkpoint = (epoch != 0) and (epoch % check_freq == 0) + + last_save_losses.append(loss_value) + + if summerize: + llprint("\n\tAvg. Logistic Loss: %.4f\n" % (np.mean(last_save_losses))) + last_save_losses = [] + + if take_checkpoint: + llprint("\nSaving Checkpoint ... 
"), + check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch)) + cur_weights = rnn.state_dict() + T.save(cur_weights, check_ptr) + llprint("Done!\n") diff --git a/dnc/dnc.py b/dnc/dnc.py new file mode 100644 index 0000000..7faf045 --- /dev/null +++ b/dnc/dnc.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 + +import torch.nn as nn +import torch as T +from torch.autograd import Variable as var +import numpy as np + +from torch.nn.utils.rnn import pad_packed_sequence as pad +from torch.nn.utils.rnn import pack_padded_sequence as pack +from torch.nn.utils.rnn import PackedSequence + +from util import * +from memory import * + + +class DNC(nn.Module): + + def __init__( + self, + input_size, + hidden_size, + rnn_type='lstm', + num_layers=1, + bias=True, + batch_first=True, + dropout=0, + bidirectional=False, + nr_cells=5, + read_heads=2, + cell_size=10, + nonlinearity='tanh', + gpu_id=-1, + independent_linears=False, + share_memory=True + ): + super(DNC, self).__init__() + # todo: separate weights and RNNs for the interface and output vectors + + self.input_size = input_size + self.hidden_size = hidden_size + self.rnn_type = rnn_type + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = dropout + self.bidirectional = bidirectional + self.nr_cells = nr_cells + self.read_heads = read_heads + self.cell_size = cell_size + self.nonlinearity = nonlinearity + self.gpu_id = gpu_id + self.independent_linears = independent_linears + self.share_memory = share_memory + + self.w = self.cell_size + self.r = self.read_heads + + # input size of layer 0 + self.layer0_input_size = self.r * self.w + self.input_size + # input size of subsequent layers + self.layern_input_size = self.r * self.w + self.hidden_size + + self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3 + self.output_size = self.hidden_size + + self.rnns = [] + self.memories = [] + + for layer in range(self.num_layers): + # controllers for each layer + if 
self.rnn_type.lower() == 'rnn': + if layer == 0: + self.rnns.append(nn.RNNCell(self.layer0_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)) + else: + self.rnns.append(nn.RNNCell(self.layern_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity)) + elif self.rnn_type.lower() == 'gru': + if layer == 0: + self.rnns.append(nn.GRUCell(self.layer0_input_size, self.output_size, bias=self.bias)) + else: + self.rnns.append(nn.GRUCell(self.layern_input_size, self.output_size, bias=self.bias)) + elif self.rnn_type.lower() == 'lstm': + # if layer == 0: + self.rnns.append(nn.LSTMCell(self.layer0_input_size, self.output_size, bias=self.bias)) + # else: + # self.rnns.append(nn.LSTMCell(self.layern_input_size, self.output_size, bias=self.bias)) + + # memories for each layer + if not self.share_memory: + self.memories.append( + Memory( + input_size=self.output_size, + mem_size=self.nr_cells, + cell_size=self.w, + read_heads=self.r, + gpu_id=self.gpu_id, + independent_linears=self.independent_linears + ) + ) + + # only one memory shared by all layers + if self.share_memory: + self.memories.append( + Memory( + input_size=self.output_size, + mem_size=self.nr_cells, + cell_size=self.w, + read_heads=self.r, + gpu_id=self.gpu_id, + independent_linears=self.independent_linears + ) + ) + + for layer in range(self.num_layers): + setattr(self, 'rnn_layer_' + str(layer), self.rnns[layer]) + if not self.share_memory: + setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer]) + if self.share_memory: + setattr(self, 'rnn_layer_memory_shared', self.memories[0]) + + # final output layer + self.output_weights = nn.Linear(self.output_size, self.output_size) + self.mem_out = nn.Linear(self.layern_input_size, self.input_size) + self.dropout_layer = nn.Dropout(self.dropout) + + if self.gpu_id != -1: + [x.cuda(self.gpu_id) for x in self.rnns] + [x.cuda(self.gpu_id) for x in self.memories] + self.mem_out.cuda(self.gpu_id) + + def 
_init_hidden(self, hx, batch_size, reset_experience): + # create empty hidden states if not provided + if hx is None: + hx = (None, None, None) + (chx, mhx, last_read) = hx + + # initialize hidden state of the controller RNN + if chx is None: + chx = cuda(T.zeros(self.num_layers, batch_size, self.output_size), gpu_id=self.gpu_id) + if self.rnn_type.lower() == 'lstm': + chx = (chx, chx) + + # Last read vectors + if last_read is None: + last_read = cuda(T.zeros(batch_size, self.w * self.r), gpu_id=self.gpu_id) + + # memory states + if mhx is None: + if self.share_memory: + mhx = self.memories[0].reset(batch_size, erase=reset_experience) + else: + mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories] + else: + if self.share_memory: + mhx = self.memories[0].reset(batch_size, mhx, erase=reset_experience) + else: + mhx = [m.reset(batch_size, h, erase=reset_experience) for m, h in zip(self.memories, mhx)] + + return chx, mhx, last_read + + def _layer_forward(self, input, layer, hx=(None, None)): + (chx, mhx) = hx + max_length = len(input) + outs = [0] * max_length + read_vectors = [0] * max_length + + for time in range(max_length): + # pass through controller + # print('input[time]', input[time].size(), self.layer0_input_size, self.layern_input_size) + chx = self.rnns[layer](input[time], chx) + # the interface vector + ξ = chx[0] if self.rnn_type.lower() == 'lstm' else chx + # the output + out = self.output_weights(chx[0]) + + # pass through memory + if self.share_memory: + read_vecs, mhx = self.memories[0](ξ, mhx) + else: + read_vecs, mhx = self.memories[layer](ξ, mhx) + read_vectors[time] = read_vecs.view(-1, self.w * self.r) + + # get the final output for this time step + outs[time] = self.mem_out(T.cat([out, read_vectors[time]], 1)) + + return outs, read_vectors, (chx, mhx) + + def forward(self, input, hx=(None, None, None), reset_experience=False): + # handle packed data + is_packed = type(input) is PackedSequence + if is_packed: + input, lengths 
= pad(input) + max_length = lengths[0] + else: + max_length = input.size(1) if self.batch_first else input.size(0) + lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length + + batch_size = input.size(0) if self.batch_first else input.size(1) + + # make the data batch-first + if not self.batch_first: + input = input.transpose(0, 1) + + controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience) + + # batched forward pass per element / word / etc + outputs = None + chxs = [] + read_vectors = [last_read] * max_length + # outs = [input[:, x, :] for x in range(max_length)] + outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)] + + # chx = [x[0] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[0] + for layer in range(self.num_layers): + # this layer's hidden states + chx = [x[layer] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[layer] + + m = mem_hidden if self.share_memory else mem_hidden[layer] + # pass through controller + outs, _, (chx, m) = self._layer_forward( + outs, + layer, + (chx, m) + ) + + # store the memory back (per layer or shared) + if self.share_memory: + mem_hidden = m + else: + mem_hidden[layer] = m + chxs.append(chx) + + if layer == self.num_layers - 1: + # final outputs + outputs = T.stack(outs, 1) + else: + # the controller output + read vectors go into next layer + outs = [T.cat([o, r], 1) for o, r in zip(outs, read_vectors)] + # outs = [o for o in outs] + + # final hidden values + if self.rnn_type.lower() == 'lstm': + h = T.stack([x[0] for x in chxs], 0) + c = T.stack([x[1] for x in chxs], 0) + controller_hidden = (h, c) + else: + controller_hidden = T.stack(chxs, 0) + + if not self.batch_first: + outputs = outputs.transpose(0, 1) + if is_packed: + outputs = pack(output, lengths) + + # apply_dict(locals()) + + return outputs, (controller_hidden, mem_hidden, read_vectors[-1]) diff 
--git a/dnc/memory.py b/dnc/memory.py new file mode 100644 index 0000000..f41a11c --- /dev/null +++ b/dnc/memory.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +import torch.nn as nn +import torch as T +from torch.autograd import Variable as var +import torch.nn.functional as F +import numpy as np + +from util import * + + +class Memory(nn.Module): + + def __init__(self, input_size, mem_size=512, cell_size=32, read_heads=4, gpu_id=-1, independent_linears=True): + super(Memory, self).__init__() + + self.mem_size = mem_size + self.cell_size = cell_size + self.read_heads = read_heads + self.gpu_id = gpu_id + self.input_size = input_size + self.independent_linears = independent_linears + + m = self.mem_size + w = self.cell_size + r = self.read_heads + + if self.independent_linears: + self.read_keys_transform = nn.Linear(self.input_size, w * r) + self.read_strengths_transform = nn.Linear(self.input_size, r) + self.write_key_transform = nn.Linear(self.input_size, w) + self.write_strength_transform = nn.Linear(self.input_size, 1) + self.erase_vector_transform = nn.Linear(self.input_size, w) + self.write_vector_transform = nn.Linear(self.input_size, w) + self.free_gates_transform = nn.Linear(self.input_size, r) + self.allocation_gate_transform = nn.Linear(self.input_size, 1) + self.write_gate_transform = nn.Linear(self.input_size, 1) + self.read_modes_transform = nn.Linear(self.input_size, 3 * r) + else: + self.interface_size = (w * r) + (3 * w) + (5 * r) + 3 + self.interface_weights = nn.Linear(self.input_size, self.interface_size) + + self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n) + + def reset(self, batch_size=1, hidden=None, erase=True): + m = self.mem_size + w = self.cell_size + r = self.read_heads + b = batch_size + + if hidden is None: + return { + 'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id), + 'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id), + 'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id), + 
'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id), + 'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id), + 'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id) + } + else: + hidden['memory'] = hidden['memory'].clone() + hidden['link_matrix'] = hidden['link_matrix'].clone() + hidden['precedence'] = hidden['precedence'].clone() + hidden['read_weights'] = hidden['read_weights'].clone() + hidden['write_weights'] = hidden['write_weights'].clone() + hidden['usage_vector'] = hidden['usage_vector'].clone() + + if erase: + hidden['memory'].data.fill_(δ) + hidden['link_matrix'].data.zero_() + hidden['precedence'].data.zero_() + hidden['read_weights'].data.fill_(δ) + hidden['write_weights'].data.fill_(δ) + hidden['usage_vector'].data.zero_() + return hidden + + def get_usage_vector(self, usage, free_gates, read_weights, write_weights): + # write_weights = write_weights.detach() # detach from the computation graph + usage = usage + (1 - usage) * (1 - T.prod(1 - write_weights, 1)) + ψ = T.prod(1 - free_gates.unsqueeze(2) * read_weights, 1) + return usage * ψ + + def allocate(self, usage, write_gate): + # ensure values are not too small prior to cumprod. 
+ usage = δ + (1 - δ) * usage + # free list + sorted_usage, φ = T.topk(usage, self.mem_size, dim=1, largest=False) + # TODO: these are actually shifted cumprods, tensorflow has exclusive=True + # fix once pytorch issue is fixed + sorted_allocation_weights = (1 - sorted_usage) * fake_cumprod(sorted_usage, self.gpu_id).squeeze() + # construct the reverse sorting index https://stackoverflow.com/questions/2483696/undo-or-reverse-argsort-python + _, φ_rev = T.topk(φ, k=self.mem_size, dim=1, largest=False) + allocation_weights = sorted_allocation_weights.gather(1, φ.long()) + + # update usage after allocating + # usage += ((1 - usage) * write_gate * allocation_weights) + return allocation_weights.unsqueeze(1), usage + + def write_weighting(self, memory, write_content_weights, allocation_weights, write_gate, allocation_gate): + ag = allocation_gate.unsqueeze(-1) + wg = write_gate.unsqueeze(-1) + + return wg * (ag * allocation_weights + (1 - ag) * write_content_weights) + + def get_link_matrix(self, link_matrix, write_weights, precedence): + precedence = precedence.unsqueeze(2) + write_weights_i = write_weights.unsqueeze(3) + write_weights_j = write_weights.unsqueeze(2) + + prev_scale = 1 - write_weights_i - write_weights_j + new_link_matrix = write_weights_i * precedence + + link_matrix = prev_scale * link_matrix + new_link_matrix + # elaborate trick to delete diag elems + return self.I.expand_as(link_matrix) * link_matrix + + def update_precedence(self, precedence, write_weights): + return (1 - T.sum(write_weights, 2, keepdim=True)) * precedence + write_weights + + def write(self, write_key, write_vector, erase_vector, free_gates, read_strengths, write_strength, write_gate, allocation_gate, hidden): + # get current usage + hidden['usage_vector'] = self.get_usage_vector( + hidden['usage_vector'], + free_gates, + hidden['read_weights'], + hidden['write_weights'] + ) + + # lookup memory with write_key and write_strength + write_content_weights = 
self.content_weightings(hidden['memory'], write_key, write_strength) + + # get memory allocation + alloc, _ = self.allocate( + hidden['usage_vector'], + allocation_gate * write_gate + ) + + # get write weightings + hidden['write_weights'] = self.write_weighting( + hidden['memory'], + write_content_weights, + alloc, + write_gate, + allocation_gate + ) + + weighted_resets = hidden['write_weights'].unsqueeze(3) * erase_vector.unsqueeze(2) + reset_gate = T.prod(1 - weighted_resets, 1) + # Update memory + hidden['memory'] = hidden['memory'] * reset_gate + + hidden['memory'] = hidden['memory'] + \ + T.bmm(hidden['write_weights'].transpose(1, 2), write_vector) + + # update link_matrix + hidden['link_matrix'] = self.get_link_matrix( + hidden['link_matrix'], + hidden['write_weights'], + hidden['precedence'] + ) + hidden['precedence'] = self.update_precedence(hidden['precedence'], hidden['write_weights']) + + return hidden + + def content_weightings(self, memory, keys, strengths): + d = θ(memory, keys) + strengths = F.softplus(strengths).unsqueeze(2) + return σ(d * strengths, 2) + + def directional_weightings(self, link_matrix, read_weights): + rw = read_weights.unsqueeze(1) + + f = T.matmul(link_matrix, rw.transpose(2, 3)).transpose(2, 3) + b = T.matmul(rw, link_matrix) + return f.transpose(1, 2), b.transpose(1, 2) + + def read_weightings(self, memory, content_weights, link_matrix, read_modes, read_weights): + forward_weight, backward_weight = self.directional_weightings(link_matrix, read_weights) + + content_mode = read_modes[:, :, 2].contiguous().unsqueeze(2) * content_weights + backward_mode = T.sum(read_modes[:, :, 0:1].contiguous().unsqueeze(3) * backward_weight, 2) + forward_mode = T.sum(read_modes[:, :, 1:2].contiguous().unsqueeze(3) * forward_weight, 2) + + return backward_mode + content_mode + forward_mode + + def read_vectors(self, memory, read_weights): + return T.bmm(read_weights, memory) + + def read(self, read_keys, read_strengths, read_modes, hidden): + 
content_weights = self.content_weightings(hidden['memory'], read_keys, read_strengths) + + hidden['read_weights'] = self.read_weightings( + hidden['memory'], + content_weights, + hidden['link_matrix'], + read_modes, + hidden['read_weights'] + ) + read_vectors = self.read_vectors(hidden['memory'], hidden['read_weights']) + return read_vectors, hidden + + def forward(self, ξ, hidden): + + # ξ = ξ.detach() + m = self.mem_size + w = self.cell_size + r = self.read_heads + b = ξ.size()[0] + + if self.independent_linears: + # r read keys (b * r * w) + read_keys = self.read_keys_transform(ξ).view(b, r, w) + # r read strengths (b * r) + read_strengths = self.read_strengths_transform(ξ).view(b, r) + # write key (b * 1 * w) + write_key = self.write_key_transform(ξ).view(b, 1, w) + # write strength (b * 1) + write_strength = self.write_strength_transform(ξ).view(b, 1) + # erase vector (b * 1 * w) + erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w)) + # write vector (b * 1 * w) + write_vector = self.write_vector_transform(ξ).view(b, 1, w) + # r free gates (b * r) + free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r)) + # allocation gate (b * 1) + allocation_gate = F.sigmoid(self.allocation_gate_transform(ξ).view(b, 1)) + # write gate (b * 1) + write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1)) + # read modes (b * r * 3) + read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1) + else: + ξ = self.interface_weights(ξ) + # r read keys (b * w * r) + read_keys = ξ[:, :r * w].contiguous().view(b, r, w) + # r read strengths (b * r) + read_strengths = 1 + F.relu(ξ[:, r * w:r * w + r].contiguous().view(b, r)) + # write key (b * w * 1) + write_key = ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w) + # write strength (b * 1) + write_strength = 1 + F.relu(ξ[:, r * w + r + w].contiguous()).view(b, 1) + # erase vector (b * w) + erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w)) + # 
write vector (b * w) + write_vector = ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w) + # r free gates (b * r) + free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r)) + # allocation gate (b * 1) + allocation_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1)) + # write gate (b * 1) + write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1) + # read modes (b * 3*r) + read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 2: r * w + 5 * r + 3 * w + 2].contiguous().view(b, r, 3), 1) + + hidden = self.write(write_key, write_vector, erase_vector, free_gates, + read_strengths, write_strength, write_gate, allocation_gate, hidden) + return self.read(read_keys, read_strengths, read_modes, hidden) diff --git a/dnc/util.py b/dnc/util.py new file mode 100644 index 0000000..b3cc895 --- /dev/null +++ b/dnc/util.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 + +import torch.nn as nn +import torch as T +import torch.nn.functional as F +from torch.autograd import Variable as var +import numpy as np +import torch +from torch.autograd import Variable +import re +import string + + +def recursiveTrace(obj): + print(type(obj)) + if hasattr(obj, 'grad_fn'): + print(obj.grad_fn) + recursiveTrace(obj.grad_fn) + elif hasattr(obj, 'saved_variables'): + print(obj.requires_grad, len(obj.saved_tensors), len(obj.saved_variables)) + [print(v) for v in obj.saved_variables] + [recursiveTrace(v.grad_fn) for v in obj.saved_variables] + + +def cuda(x, grad=False, gpu_id=-1): + if gpu_id == -1: + return var(x, requires_grad=grad) + else: + return var(x.pin_memory(), requires_grad=grad).cuda(gpu_id, async=True) + + +def cudavec(x, grad=False, gpu_id=-1): + if gpu_id == -1: + return var(T.from_numpy(x), requires_grad=grad) + else: + return var(T.from_numpy(x).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True) + + +def cudalong(x, grad=False, gpu_id=-1): + if 
gpu_id == -1: + return var(T.from_numpy(x.astype(np.long)), requires_grad=grad) + else: + return var(T.from_numpy(x.astype(np.long)).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True) + + +def fake_cumprod(vb, gpu_id): + """ + args: + vb: [hei x wid] + -> NOTE: we are lazy here so now it only supports cumprod along wid + """ + # real_cumprod = torch.cumprod(vb.data, 1) + vb = vb.unsqueeze(0) + mul_mask_vb = Variable(torch.zeros(vb.size(2), vb.size(1), vb.size(2))).type_as(vb) + + if gpu_id != -1: + mul_mask_vb = mul_mask_vb.cuda(gpu_id) + + for i in range(vb.size(2)): + mul_mask_vb[i, :, :i + 1] = 1 + add_mask_vb = 1 - mul_mask_vb + vb = vb.expand_as(mul_mask_vb) * mul_mask_vb + add_mask_vb + # vb = torch.prod(vb, 2).transpose(0, 2) # 0.1.12 + vb = torch.prod(vb, 2, keepdim=True).transpose(0, 2) # 0.2.0 + # print(real_cumprod - vb.data) # NOTE: checked, ==0 + return vb + + +def θ(a, b, dimA=2, dimB=2, normBy=2): + """Batchwise Cosine distance + + Cosine distance + + Arguments: + a {Tensor} -- A 3D Tensor (b * m * w) + b {Tensor} -- A 3D Tensor (b * r * w) + + Keyword Arguments: + dimA {number} -- exponent value of the norm for `a` (default: {2}) + dimB {number} -- exponent value of the norm for `b` (default: {1}) + + Returns: + Tensor -- Batchwise cosine distance (b * r * m) + """ + a_norm = T.norm(a, normBy, dimA, keepdim=True).expand_as(a) + δ + b_norm = T.norm(b, normBy, dimB, keepdim=True).expand_as(b) + δ + + x = T.bmm(a, b.transpose(1, 2)).transpose(1, 2) / ( + T.bmm(a_norm, b_norm.transpose(1, 2)).transpose(1, 2) + δ) + # apply_dict(locals()) + return x + + +def σ(input, axis=1): + """Softmax on an axis + + Softmax on an axis + + Arguments: + input {Tensor} -- input Tensor + + Keyword Arguments: + axis {number} -- axis on which to take softmax on (default: {1}) + + Returns: + Tensor -- Softmax output Tensor + """ + input_size = input.size() + + trans_input = input.transpose(axis, len(input_size) - 1) + trans_size = trans_input.size() + + input_2d = 
trans_input.contiguous().view(-1, trans_size[-1]) + soft_max_2d = F.softmax(input_2d) + soft_max_nd = soft_max_2d.view(*trans_size) + return soft_max_nd.transpose(axis, len(input_size) - 1) + +δ = 1e-6 + + +def register_nan_checks(model): + def check_grad(module, grad_input, grad_output): + # print(module) you can add this to see that the hook is called + print('hook called for ' + str(type(module))) + if any(np.all(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None): + print('NaN gradient in grad_input ' + type(module).__name__) + + model.apply(lambda module: module.register_backward_hook(check_grad)) + + +def apply_dict(dic): + for k, v in dic.items(): + apply_var(v, k) + if isinstance(v, nn.Module): + key_list = [a for a in dir(v) if not a.startswith('__')] + for key in key_list: + apply_var(getattr(v, key), key) + for pk, pv in v._parameters.items(): + apply_var(pv, pk) + + +def apply_var(v, k): + if isinstance(v, Variable) and v.requires_grad: + v.register_hook(check_nan_gradient(k)) + + +def check_nan_gradient(name=''): + def f(tensor): + if np.isnan(T.mean(tensor).data.cpu().numpy()): + print('\nnan gradient of {} :'.format(name)) + # print(tensor) + # assert 0, 'nan gradient' + return tensor + return f diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..68c61a2 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[bdist_wheel] +universal=0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..be0b5d4 --- /dev/null +++ b/setup.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +"""A setuptools based setup module. 
+See: +https://packaging.python.org/en/latest/distributing.html +https://github.com/pypa/sampleproject +""" + +# Always prefer setuptools over distutils +from setuptools import setup, find_packages +# To use a consistent encoding +from codecs import open +from os import path + +here = path.abspath(path.dirname(__file__)) + +# Get the long description from the README file +with open(path.join(here, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +setup( + name='dnc', + + version='0.0.1', + + description='Differentiable Neural Computer, for Pytorch', + long_description=long_description, + + # The project's main homepage. + url='https://github.com/pypa/dnc', + + # Author details + author='Russi Chatterjee', + author_email='root@ixaxaar.in', + + # Choose your license + license='MIT', + + # See https://pypi.python.org/pypi?%3Aaction=list_classifiers + classifiers=[ + 'Development Status :: 3 - Alpha', + + 'Intended Audience :: Science/Research', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + + 'License :: OSI Approved :: MIT License', + + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + ], + + keywords='differentiable neural computer dnc memory network', + + packages=find_packages(exclude=['contrib', 'docs', 'tests']), + + install_requires=['torch', 'numpy'], + + extras_require={ + 'dev': ['check-manifest'], + 'test': ['coverage'], + }, + + python_requires='>=3', +)