Skip to content

Commit

Permalink
Initial commit, pushed into pypi
Browse files Browse the repository at this point in the history
  • Loading branch information
ixaxaar committed Oct 26, 2017
1 parent 397d7ee commit 90365bd
Show file tree
Hide file tree
Showing 9 changed files with 972 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ __pycache__/
*.lang
*.log
.cache/
dist/
dnc.egg-info/
69 changes: 69 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Differentiable Neural Computer, for Pytorch

This is an implementation of [Differentiable Neural Computers](people.idsia.ch/~rupesh/rnnsymposium2016/slides/graves.pdf), described in the paper [Hybrid computing using a neural network with dynamic external memory, Graves et al.](www.nature.com/articles/nature20101)

## Install

```bash
pip install dnc
```

## Usage

**Parameters**:

| Argument | Default | Description |
| --- | --- | --- |
| input_size | None | Size of the input vectors |
| hidden_size | None | Size of hidden units |
| rnn_type | 'lstm' | Type of recurrent cells used in the controller |
| num_layers | 1 | Number of layers of recurrent units in the controller |
| bias | True | Bias |
| batch_first | True | Whether data is fed batch first |
| dropout | 0 | Dropout between layers in the controller (Not yet implemented) |
| bidirectional | False | If the controller is bidirectional (Not yet implemented) |
| nr_cells | 5 | Number of memory cells |
| read_heads | 2 | Number of read heads |
| cell_size | 10 | Size of each memory cell |
| nonlinearity | 'tanh' | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | -1 | ID of the GPU, -1 for CPU |
| independent_linears | False | Whether to use independent linear units to derive interface vector |
| share_memory | True | Whether to share memory between controller layers |


Example usage:

```python
from dnc import DNC

rnn = DNC(
input_size=64,
hidden_size=128,
rnn_type='lstm',
num_layers=4,
nr_cells=100,
cell_size=32,
read_heads=4,
batch_first=True,
gpu_id=0
)

(controller_hidden, memory, read_vectors) = (None, None, None)

output, (controller_hidden, memory, read_vectors) = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors))
```

## Example copy task

The copy task, as descibed in the original paper, is included in the repo.

```
python ./copy_task.py -cuda 0
```

## General noteworthy stuff

1. DNCs converge with Adam and RMSProp learning rules, SGD generally causes them to diverge.
2. Using a large batch size (> 100, recommended 1000) prevents gradients from becoming `NaN`.

1 change: 1 addition & 0 deletions dnc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#!/usr/bin/env python3
166 changes: 166 additions & 0 deletions dnc/copy_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#!/usr/bin/env python3

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import getopt
import sys
import os
import math
import time
import argparse

sys.path.insert(0, os.path.join('..', '..'))

import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm

from dnc import DNC

parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
parser.add_argument('-input_size', type=int, default= 6, help='dimension of input feature')
parser.add_argument('-nhid', type=int, default=64, help='humber of hidden units of the inner nn')

parser.add_argument('-nlayer', type=int, default=2, help='number of layers')
parser.add_argument('-lr', type=float, default=1e-2, help='initial learning rate')
parser.add_argument('-clip', type=float, default=0.5, help='gradient clipping')

parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
parser.add_argument('-mem_size', type=int, default=16, help='memory dimension')
parser.add_argument('-mem_slot', type=int, default=15, help='number of memory slots')
parser.add_argument('-read_heads', type=int, default=1, help='number of read heads')

parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval')

parser.add_argument('-iterations', type=int, default=100000, metavar='N', help='total number of iteration')
parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='check point frequency')

args = parser.parse_args()
print(args)

if args.cuda != -1:
print('Using CUDA.')
T.manual_seed(1111)
else:
print('Using CPU.')


def llprint(message):
sys.stdout.write(message)
sys.stdout.flush()


def generate_data(batch_size, length, size, cuda=-1):

input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)
target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)

sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1))

input_data[:, :length, :size - 1] = sequence
input_data[:, length, -1] = 1 # the end symbol
target_output[:, length + 1:, :size - 1] = sequence

input_data = T.from_numpy(input_data)
target_output = T.from_numpy(target_output)
if cuda != -1:
input_data = input_data.cuda()
target_output = target_output.cuda()

return var(input_data), var(target_output)


def criterion(predictions, targets):
return T.mean(
-1 * F.logsigmoid(predictions) * (targets) - T.log(1 - F.sigmoid(predictions) + 1e-9) * (1 - targets)
)

if __name__ == '__main__':

dirname = os.path.dirname(__file__)
ckpts_dir = os.path.join(dirname, 'checkpoints')
if not os.path.isdir(ckpts_dir):
os.mkdir(ckpts_dir)

batch_size = args.batch_size
sequence_max_length = args.sequence_max_length
iterations = args.iterations
summarize_freq = args.summarize_freq
check_freq = args.check_freq

# input_size = output_size = args.input_size
mem_slot = args.mem_slot
mem_size = args.mem_size
read_heads = args.read_heads


# options, _ = getopt.getopt(sys.argv[1:], '', ['iterations='])

# for opt in options:
# if opt[0] == '-iterations':
# iterations = int(opt[1])

rnn = DNC(
input_size=args.input_size,
hidden_size=args.nhid,
rnn_type='lstm',
num_layers=args.nlayer,
nr_cells=mem_slot,
cell_size=mem_size,
read_heads=read_heads,
gpu_id=args.cuda
)

if args.cuda != -1:
rnn = rnn.cuda(args.cuda)

last_save_losses = []

optimizer = optim.Adam(rnn.parameters(), lr=args.lr)

for epoch in range(iterations + 1):
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
optimizer.zero_grad()

random_length = np.random.randint(1, sequence_max_length + 1)

input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
# input_data = input_data.transpose(0, 1).contiguous()
target_output = target_output.transpose(0, 1).contiguous()

output, _ = rnn(input_data, None)
output = output.transpose(0, 1)

loss = criterion((output), target_output)
# if np.isnan(loss.data.cpu().numpy()):
# llprint('\nGot nan loss, contine to jump the backward \n')

# apply_dict(locals())
loss.backward()

optimizer.step()
loss_value = loss.data[0]

summerize = (epoch % summarize_freq == 0)
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)

last_save_losses.append(loss_value)

if summerize:
llprint("\n\tAvg. Logistic Loss: %.4f\n" % (np.mean(last_save_losses)))
last_save_losses = []

if take_checkpoint:
llprint("\nSaving Checkpoint ... "),
check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))
cur_weights = rnn.state_dict()
T.save(cur_weights, check_ptr)
llprint("Done!\n")
Loading

0 comments on commit 90365bd

Please sign in to comment.