Add 3D blurpooling and code cleanup for 1D and 2D blurpooling #39

Open · wants to merge 1 commit into base: master
183 changes: 103 additions & 80 deletions antialiased_cnns/blurpool.py
@@ -10,108 +10,131 @@
import torch.nn as nn
import torch.nn.functional as F
-from IPython import embed
+from functools import partial

-class BlurPool(nn.Module):
-    def __init__(self, channels, pad_type='reflect', filt_size=4, stride=2, pad_off=0):
-        super(BlurPool, self).__init__()
-        self.filt_size = filt_size
-        self.pad_off = pad_off
-        self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
-        self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
-        self.stride = stride
-        self.off = int((self.stride-1)/2.)
-        self.channels = channels

-        if(self.filt_size==1):
-            a = np.array([1.,])
-        elif(self.filt_size==2):
-            a = np.array([1., 1.])
-        elif(self.filt_size==3):
-            a = np.array([1., 2., 1.])
-        elif(self.filt_size==4):
-            a = np.array([1., 3., 3., 1.])
-        elif(self.filt_size==5):
-            a = np.array([1., 4., 6., 4., 1.])
-        elif(self.filt_size==6):
-            a = np.array([1., 5., 10., 10., 5., 1.])
-        elif(self.filt_size==7):
-            a = np.array([1., 6., 15., 20., 15., 6., 1.])
+class ZeroPad1d(torch.nn.modules.padding.ConstantPad1d):
+    def __init__(self, padding):
+        super(ZeroPad1d, self).__init__(padding, 0.)

-        filt = torch.Tensor(a[:,None]*a[None,:])
-        filt = filt/torch.sum(filt)
-        self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))

-        self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+class ZeroPad3d(torch.nn.modules.padding.ConstantPad3d):
+    def __init__(self, padding):
+        super(ZeroPad3d, self).__init__(padding, 0.)

-    def forward(self, inp):
-        if(self.filt_size==1):
-            if(self.pad_off==0):
-                return inp[:,:,::self.stride,::self.stride]
-            else:
-                return self.pad(inp)[:,:,::self.stride,::self.stride]
-        else:
-            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

-def get_pad_layer(pad_type):
-    if(pad_type in ['refl','reflect']):
-        PadLayer = nn.ReflectionPad2d
-    elif(pad_type in ['repl','replicate']):
-        PadLayer = nn.ReplicationPad2d
-    elif(pad_type=='zero'):
-        PadLayer = nn.ZeroPad2d
-    else:
-        print('Pad type [%s] not recognized'%pad_type)
-    return PadLayer

-class BlurPool1D(nn.Module):
-    def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
-        super(BlurPool1D, self).__init__()

+def get_padding_layer(pad_type, dim):
+    pad_layer_dict = {
+        'reflect': {
+            1: nn.ReflectionPad1d,
+            2: nn.ReflectionPad2d,
+        },
+        'replicate': {
+            1: nn.ReplicationPad1d,
+            2: nn.ReplicationPad2d,
+            3: nn.ReplicationPad3d
+        },
+        'zero': {
+            1: ZeroPad1d,
+            2: nn.ZeroPad2d,
+            3: ZeroPad3d
+        }
+    }
+    pad_layer_dict['refl'] = pad_layer_dict['reflect']
+    pad_layer_dict['repl'] = pad_layer_dict['replicate']
+    try:
+        return pad_layer_dict[pad_type][dim]
+    except KeyError:
+        raise NotImplementedError


+class BlurPoolND(nn.Module):

+    def __init__(self, channels, pad_type='reflect', filt_size=4, stride=2, pad_off=0, dims=2):

+        super(BlurPoolND, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
-        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
+        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] * dims
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # print('Filter size [%i]' % filt_size)
-        if(self.filt_size == 1):
+        a = None
+        if self.filt_size == 1:
            a = np.array([1., ])
-        elif(self.filt_size == 2):
+        elif self.filt_size == 2:
            a = np.array([1., 1.])
-        elif(self.filt_size == 3):
+        elif self.filt_size == 3:
            a = np.array([1., 2., 1.])
-        elif(self.filt_size == 4):
+        elif self.filt_size == 4:
            a = np.array([1., 3., 3., 1.])
-        elif(self.filt_size == 5):
+        elif self.filt_size == 5:
            a = np.array([1., 4., 6., 4., 1.])
-        elif(self.filt_size == 6):
+        elif self.filt_size == 6:
            a = np.array([1., 5., 10., 10., 5., 1.])
-        elif(self.filt_size == 7):
+        elif self.filt_size == 7:
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

-        filt = torch.Tensor(a)
+        filt = create_filter(a, dims)
        filt = filt / torch.sum(filt)
-        self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))

-        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
+        self.register_buffer('filt', get_filter_for_registration(filt, self.channels, dims))
+        self.pad = get_padding_layer(pad_type, dims)(self.pad_sizes)
+        self.identity_inference = get_identity_inference_func(dims)
+        self.conv_func = get_conv_func(dims)

    def forward(self, inp):
-        if(self.filt_size == 1):
-            if(self.pad_off == 0):
-                return inp[:, :, ::self.stride]
+        if self.filt_size == 1:
+            if self.pad_off == 0:
+                return self.identity_inference(inp, self.stride)
            else:
-                return self.pad(inp)[:, :, ::self.stride]
+                return self.identity_inference(self.pad(inp), self.stride)
        else:
-            return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

-def get_pad_layer_1d(pad_type):
-    if(pad_type in ['refl', 'reflect']):
-        PadLayer = nn.ReflectionPad1d
-    elif(pad_type in ['repl', 'replicate']):
-        PadLayer = nn.ReplicationPad1d
-    elif(pad_type == 'zero'):
-        PadLayer = nn.ZeroPad1d
-    else:
-        print('Pad type [%s] not recognized' % pad_type)
-    return PadLayer
+            return self.conv_func(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])


+def create_filter(a, dims):
+    if dims == 1:
+        return torch.Tensor(a)
+    elif dims == 2:
+        return torch.Tensor(a[:, None] * a[None, :])
+    elif dims == 3:
+        b = a[:, None] * a[None, :]
+        return torch.Tensor(a[:, None, None] * b[None, :, :])
+    raise NotImplementedError


+def get_conv_func(dims):
+    if dims == 1:
+        return F.conv1d
+    elif dims == 2:
+        return F.conv2d
+    elif dims == 3:
+        return F.conv3d
+    raise NotImplementedError


+def get_filter_for_registration(filt, channels, dims):
+    if dims == 1:
+        return filt[None, None, :].repeat((channels, 1, 1))
+    elif dims == 2:
+        return filt[None, None, :, :].repeat((channels, 1, 1, 1))
+    elif dims == 3:
+        return filt[None, None, :, :, :].repeat((channels, 1, 1, 1, 1))
+    raise NotImplementedError


+def get_identity_inference_func(dims):
+    if dims == 1:
+        return lambda x, s: x[:, :, ::s]
+    elif dims == 2:
+        return lambda x, s: x[:, :, ::s, ::s]
+    elif dims == 3:
+        return lambda x, s: x[:, :, ::s, ::s, ::s]
+    raise NotImplementedError


+BlurPool1D = partial(BlurPoolND, dims=1)
+BlurPool = BlurPoolND
+BlurPool3D = partial(BlurPoolND, dims=3)
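
For reviewers, a minimal usage sketch of the unified module this PR proposes. The channel counts, tensor shapes, and direct module import path below are illustrative only and not part of the diff; it assumes the partial aliases defined at the bottom of blurpool.py.

import torch
from antialiased_cnns.blurpool import BlurPool, BlurPool1D, BlurPool3D

# 1D: (batch, channels, length); filt_size=3 matches the old BlurPool1D default
# (the new partial alias otherwise inherits BlurPoolND's filt_size=4)
bp1 = BlurPool1D(channels=8, filt_size=3, stride=2)
print(bp1(torch.randn(2, 8, 32)).shape)    # torch.Size([2, 8, 16])

# 2D: BlurPool keeps its old name, now as BlurPoolND with the default dims=2
bp2 = BlurPool(channels=8, filt_size=4, stride=2)
print(bp2(torch.randn(2, 8, 32, 32)).shape)    # torch.Size([2, 8, 16, 16])

# 3D: new in this PR; get_padding_layer has no reflect entry for dims=3,
# so replicate (or zero) padding is used for volumetric inputs
bp3 = BlurPool3D(channels=8, filt_size=4, stride=2, pad_type='replicate')
print(bp3(torch.randn(2, 8, 16, 32, 32)).shape)    # torch.Size([2, 8, 8, 16, 16])

As in the original code, the blur kernel is the normalized outer product of a binomial row (e.g. [1, 3, 3, 1] for filt_size=4), applied as a grouped convolution so each channel is filtered independently before the strided subsampling.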