Commit a120df6
Fix issue 91 where dropconnect was not parallel (#93)
Frédéric Branchaud-Charron authored Nov 3, 2020
1 parent f6a0352 commit a120df6
Showing 6 changed files with 31 additions and 94 deletions.
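
For context on the fix: the removed `_weight_drop` helper replaced each patched layer's `forward` with a closure bound to the original module and re-assigned the dropped weight via `setattr` on every call, which breaks when the module is copied or replicated (for example by `torch.nn.DataParallel`) and is presumably why DropConnect models could not run in parallel (issue 91). The new `WeightDropLinear` and `WeightDropConv2d` instead define `forward` as a regular method that recomputes the dropped weight from `self.weight`, so a patched model behaves like any ordinary module. A minimal usage sketch follows; the toy model and the DataParallel wrapping are illustrative assumptions, not part of this commit:

# Sketch only: model and DataParallel use are assumptions for illustration.
import torch
from baal.bayesian.weight_drop import patch_module

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 8 * 8, 1),
)
# Replace Linear and Conv2d layers with their weight-dropping counterparts.
mc_model = patch_module(model, layers=['Linear', 'Conv2d'], weight_dropout=0.2)
mc_model = torch.nn.DataParallel(mc_model)  # replication is possible now that forward is a real method
mc_model.eval()

x = torch.randn(4, 3, 10, 10)
# The weight mask is re-sampled on every call, so repeated MC passes differ even in eval mode.
assert not torch.allclose(mc_model(x), mc_model(x))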
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -29,7 +29,7 @@ jobs:
pip install -e .
pip install -r test-requirements.txt
pip uninstall -y torch torchvision
-pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html --trusted-host download.pytorch.org
+pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html --trusted-host download.pytorch.org
- run:
name: Setup Environment Variables
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -40,7 +40,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
-release = '1.2.0'
+release = '1.2.1'

# -- General configuration ---------------------------------------------------

4 changes: 2 additions & 2 deletions requirements.txt
@@ -4,8 +4,8 @@ pandas
Pillow>=6.2.0
scikit-learn
scipy
-torch>=1.2.0
-torchvision>=0.2.2
+torch>=1.6.0
+torchvision>=0.7.0
tqdm
h5py
structlog
50 changes: 17 additions & 33 deletions src/baal/bayesian/weight_drop.py
@@ -1,8 +1,7 @@
+import copy
 import warnings
 from typing import List
-import copy

-from torch.nn import Parameter
 import torch

Sequence = List[str]
@@ -15,32 +14,6 @@ def get_weight_drop_module(name: str, weight_dropout, **kwargs):
}[name](weight_dropout, **kwargs)


-# Code from https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/weight_drop.html
-def _weight_drop(module, weights, dropout):
-    """
-    Helper for `WeightDrop`.
-    """
-
-    for name_w in weights:
-        w = getattr(module, name_w)
-        del module._parameters[name_w]
-        module.register_parameter(name_w + '_raw', Parameter(w))
-
-    original_module_forward = module.forward
-
-    def forward(*args, **kwargs):
-        for name_w in weights:
-            raw_w = getattr(module, name_w + '_raw')
-
-            # dropout should work in inference time as well
-            w = torch.nn.functional.dropout(raw_w, p=dropout, training=True)
-            setattr(module, name_w, w)
-
-        return original_module_forward(*args, **kwargs)
-
-    setattr(module, 'forward', forward)


class WeightDropLinear(torch.nn.Linear):
"""
Thanks to PytorchNLP for the initial implementation
@@ -57,8 +30,11 @@ def __init__(self, weight_dropout=0.0, **kwargs):
         wanted = ['in_features', 'out_features']
         kwargs = {k: v for k, v in kwargs.items() if k in wanted}
         super().__init__(**kwargs)
-        weights = ['weight']
-        _weight_drop(self, weights, weight_dropout)
+        self._weight_dropout = weight_dropout
+
+    def forward(self, input):
+        w = torch.nn.functional.dropout(self.weight, p=self._weight_dropout, training=True)
+        return torch.nn.functional.linear(input, w, self.bias)


class WeightDropConv2d(torch.nn.Conv2d):
@@ -71,12 +47,17 @@ class WeightDropConv2d(torch.nn.Conv2d):
Args:
weight_dropout (float): The probability a weight will be dropped.
"""

def __init__(self, weight_dropout=0.0, **kwargs):
         wanted = ['in_channels', 'out_channels', 'kernel_size', 'dilation', 'padding']
         kwargs = {k: v for k, v in kwargs.items() if k in wanted}
         super().__init__(**kwargs)
-        weights = ['weight']
-        _weight_drop(self, weights, weight_dropout)
+        self._weight_dropout = weight_dropout
+
+    def forward(self, input):
+        return self._conv_forward(input, torch.nn.functional.dropout(self.weight,
+                                                                      p=self._weight_dropout,
+                                                                      training=True))


def patch_module(module: torch.nn.Module,
@@ -145,8 +126,11 @@ class MCDropoutConnectModule(torch.nn.Module):
Name of layers to be replaced from ['Conv', 'Linear', 'LSTM', 'GRU'].
weight_dropout (float): The probability a weight will be dropped.
"""

     def __init__(self, module: torch.nn.Module, layers: Sequence, weight_dropout=0.0):
         super().__init__()
         self.parent_module = module
         _patch_layers(self.parent_module, layers, weight_dropout)
-        self.forward = self.parent_module.forward
+
+    def forward(self, x):
+        return self.parent_module(x)
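
Both replacement layers now apply `torch.nn.functional.dropout(self.weight, p=self._weight_dropout, training=True)` inside `forward`, so a fresh mask is sampled on every pass even in eval mode, and the `weight_raw` re-registration done by the removed helper is no longer needed. A small sketch of the resulting behaviour, with illustrative sizes that are not taken from the repository:

import torch
from baal.bayesian.weight_drop import WeightDropLinear

# weight_dropout is the probability that an individual weight is dropped.
layer = WeightDropLinear(weight_dropout=0.5, in_features=4, out_features=2)
layer.eval()  # weight dropout stays active: dropout is applied with training=True

x = torch.randn(1, 4)
# Two passes over the same input generally differ because the mask is re-sampled per call.
print(torch.allclose(layer(x), layer(x)))  # expected: False with overwhelming probability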
2 changes: 1 addition & 1 deletion src/baal/version.py
@@ -1 +1 @@
__version__ = "1.2.0"
__version__ = "1.2.1"
65 changes: 9 additions & 56 deletions tests/bayesian/dropconnect_test.py
@@ -1,26 +1,14 @@
 import warnings

-import numpy as np
 import pytest
 import torch
-from torch.utils.data import Dataset

-from baal.bayesian.weight_drop import WeightDropLinear, WeightDropConv2d, \
-    patch_module, MCDropoutConnectModule
+from baal.bayesian.weight_drop import patch_module


-class DummyDataset(Dataset):
-    def __len__(self):
-        return 20
-
-    def __getitem__(self, item):
-        return torch.from_numpy(np.ones([3, 10, 10]) * item / 255.).float(),\
-               torch.FloatTensor([item % 2])
-
-
-class DummyModel(torch.nn.Module):
+class SimpleModel(torch.nn.Module):
     def __init__(self):
-        super(DummyModel, self).__init__()
+        super(SimpleModel, self).__init__()
         self.conv = torch.nn.Conv2d(3, 8, kernel_size=10)
         self.relu = torch.nn.ReLU()
         self.dropout = torch.nn.Dropout()
@@ -36,36 +24,18 @@ def forward(self, x):


@pytest.mark.parametrize("inplace", (True, False))
@pytest.mark.parametrize("layers", (['Linear'], ['Linear', 'Conv2d']))
@pytest.mark.parametrize("layers", (['Linear'], ['Linear', 'Conv2d'], ['Conv2d']))
def test_patch_module_changes_weights(inplace, layers):
-    test_module = torch.nn.Sequential(
-        torch.nn.Conv2d(3, 8, kernel_size=10),
-        torch.nn.ReLU(),
-        torch.nn.Dropout(p=0.5),
-        torch.nn.Linear(8, 1),
-    )
-
-    conv_w = list(test_module.modules())[1].weight.clone().detach().numpy()
-    linear_w = list(test_module.modules())[-1].weight.clone().detach().numpy()
+    test_module = SimpleModel()
+    test_module.eval()
+    simple_input = torch.randn(10, 3, 10, 10)
+    assert torch.allclose(test_module(simple_input), test_module(simple_input))

     mc_test_module = patch_module(test_module, layers=layers, weight_dropout=0.2, inplace=inplace)

     # objects should be the same if inplace is True and not otherwise:
     assert (mc_test_module is test_module) == inplace
-
-    new_linear_w = list(mc_test_module.modules())[-1].weight_raw.clone().detach().numpy()
-    if layers == ['Linear']:
-        assert isinstance(list(mc_test_module.modules())[-1], WeightDropLinear)
-        assert isinstance(list(mc_test_module.modules())[1], torch.nn.Conv2d)
-        new_conv_w = list(mc_test_module.modules())[1].weight.clone().detach().numpy()
-        assert np.allclose(new_conv_w, conv_w)
-        assert not np.allclose(new_linear_w, linear_w)
-    else:
-        assert isinstance(list(mc_test_module.modules())[-1], WeightDropLinear)
-        assert isinstance(list(mc_test_module.modules())[1], WeightDropConv2d)
-        new_conv_w = list(mc_test_module.modules())[1].weight_raw.clone().detach().numpy()
-        assert not np.allclose(new_conv_w, conv_w)
-        assert not np.allclose(new_linear_w, linear_w)
+    assert not torch.allclose(mc_test_module(simple_input), mc_test_module(simple_input))

     assert list(mc_test_module.modules())[3].p == 0

@@ -87,22 +57,5 @@ def test_patch_module_raise_warnings(inplace, layers):
assert "No layer was modified by patch_module" in str(w[-1].message)


-def test_weight_change_after_forward_pass():
-    test_module = DummyModel()
-    dataset = DummyDataset()
-    mc_test_module = MCDropoutConnectModule(test_module, layers=['Linear'], weight_dropout=0.2)
-
-    assert not hasattr(list(test_module.modules())[-1], 'weight')
-    linear_w = list(test_module.modules())[-1].weight_raw.clone().detach().numpy()
-
-    input, _ = [torch.stack(v) for v in zip(*(dataset[0], dataset[1]))]
-    mc_test_module.eval()
-    out = mc_test_module(input)
-
-    assert hasattr(list(test_module.modules())[-1], 'weight')
-    new_linear_w = list(mc_test_module.modules())[-1].weight.clone().detach().numpy()
-    assert not np.allclose(new_linear_w, linear_w)


if __name__ == '__main__':
pytest.main()
