From 3d6f02b6943c58b68c19c07bc26fad57492ff3bc Mon Sep 17 00:00:00 2001 From: Yawar Siddiqui Date: Fri, 11 Sep 2020 04:08:24 +0200 Subject: [PATCH] Add 3D blurpooling and code cleanup for 1D and 2D blurpooling --- antialiased_cnns/blurpool.py | 183 ++++++++++++++++++++--------------- 1 file changed, 103 insertions(+), 80 deletions(-) diff --git a/antialiased_cnns/blurpool.py b/antialiased_cnns/blurpool.py index fd09da3..d55d74f 100644 --- a/antialiased_cnns/blurpool.py +++ b/antialiased_cnns/blurpool.py @@ -10,108 +10,131 @@ import torch.nn as nn import torch.nn.functional as F from IPython import embed +from functools import partial -class BlurPool(nn.Module): - def __init__(self, channels, pad_type='reflect', filt_size=4, stride=2, pad_off=0): - super(BlurPool, self).__init__() - self.filt_size = filt_size - self.pad_off = pad_off - self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] - self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] - self.stride = stride - self.off = int((self.stride-1)/2.) - self.channels = channels - if(self.filt_size==1): - a = np.array([1.,]) - elif(self.filt_size==2): - a = np.array([1., 1.]) - elif(self.filt_size==3): - a = np.array([1., 2., 1.]) - elif(self.filt_size==4): - a = np.array([1., 3., 3., 1.]) - elif(self.filt_size==5): - a = np.array([1., 4., 6., 4., 1.]) - elif(self.filt_size==6): - a = np.array([1., 5., 10., 10., 5., 1.]) - elif(self.filt_size==7): - a = np.array([1., 6., 15., 20., 15., 6., 1.]) +class ZeroPad1d(torch.nn.modules.padding.ConstantPad1d): + def __init__(self, padding): + super(ZeroPad1d, self).__init__(padding, 0.) - filt = torch.Tensor(a[:,None]*a[None,:]) - filt = filt/torch.sum(filt) - self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) - self.pad = get_pad_layer(pad_type)(self.pad_sizes) +class ZeroPad3d(torch.nn.modules.padding.ConstantPad3d): + def __init__(self, padding): + super(ZeroPad3d, self).__init__(padding, 0.) - def forward(self, inp): - if(self.filt_size==1): - if(self.pad_off==0): - return inp[:,:,::self.stride,::self.stride] - else: - return self.pad(inp)[:,:,::self.stride,::self.stride] - else: - return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) - -def get_pad_layer(pad_type): - if(pad_type in ['refl','reflect']): - PadLayer = nn.ReflectionPad2d - elif(pad_type in ['repl','replicate']): - PadLayer = nn.ReplicationPad2d - elif(pad_type=='zero'): - PadLayer = nn.ZeroPad2d - else: - print('Pad type [%s] not recognized'%pad_type) - return PadLayer - -class BlurPool1D(nn.Module): - def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0): - super(BlurPool1D, self).__init__() + +def get_padding_layer(pad_type, dim): + pad_layer_dict = { + 'reflect': { + 1: nn.ReflectionPad1d, + 2: nn.ReflectionPad2d, + }, + 'replicate': { + 1: nn.ReplicationPad1d, + 2: nn.ReplicationPad2d, + 3: nn.ReplicationPad3d + }, + 'zero': { + 1: ZeroPad1d, + 2: nn.ZeroPad2d, + 3: ZeroPad3d + } + } + pad_layer_dict['refl'] = pad_layer_dict['reflect'] + pad_layer_dict['repl'] = pad_layer_dict['replicate'] + try: + return pad_layer_dict[pad_type][dim] + except KeyError: + raise NotImplementedError + + +class BlurPoolND(nn.Module): + + def __init__(self, channels, pad_type='reflect', filt_size=4, stride=2, pad_off=0, dims=2): + + super(BlurPoolND, self).__init__() self.filt_size = filt_size self.pad_off = pad_off - self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] + self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] * dims self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] self.stride = stride self.off = int((self.stride - 1) / 2.) self.channels = channels - # print('Filter size [%i]' % filt_size) - if(self.filt_size == 1): + a = None + if self.filt_size == 1: a = np.array([1., ]) - elif(self.filt_size == 2): + elif self.filt_size == 2: a = np.array([1., 1.]) - elif(self.filt_size == 3): + elif self.filt_size == 3: a = np.array([1., 2., 1.]) - elif(self.filt_size == 4): + elif self.filt_size == 4: a = np.array([1., 3., 3., 1.]) - elif(self.filt_size == 5): + elif self.filt_size == 5: a = np.array([1., 4., 6., 4., 1.]) - elif(self.filt_size == 6): + elif self.filt_size == 6: a = np.array([1., 5., 10., 10., 5., 1.]) - elif(self.filt_size == 7): + elif self.filt_size == 7: a = np.array([1., 6., 15., 20., 15., 6., 1.]) - filt = torch.Tensor(a) + filt = create_filter(a, dims) filt = filt / torch.sum(filt) - self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) - - self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) + self.register_buffer('filt', get_filter_for_registration(filt, self.channels, dims)) + self.pad = get_padding_layer(pad_type, dims)(self.pad_sizes) + self.identity_inference = get_identity_inference_func(dims) + self.conv_func = get_conv_func(dims) def forward(self, inp): - if(self.filt_size == 1): - if(self.pad_off == 0): - return inp[:, :, ::self.stride] + if self.filt_size == 1: + if self.pad_off == 0: + return self.identity_inference(inp, self.stride) else: - return self.pad(inp)[:, :, ::self.stride] + return self.identity_inference(self.pad(inp), self.stride) else: - return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) - -def get_pad_layer_1d(pad_type): - if(pad_type in ['refl', 'reflect']): - PadLayer = nn.ReflectionPad1d - elif(pad_type in ['repl', 'replicate']): - PadLayer = nn.ReplicationPad1d - elif(pad_type == 'zero'): - PadLayer = nn.ZeroPad1d - else: - print('Pad type [%s] not recognized' % pad_type) - return PadLayer + return self.conv_func(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + + +def create_filter(a, dims): + if dims == 1: + return torch.Tensor(a) + elif dims == 2: + return torch.Tensor(a[:, None] * a[None, :]) + elif dims == 3: + b = a[:, None] * a[None, :] + return torch.Tensor(a[:, None, None] * b[None, :, :]) + return NotImplementedError + + +def get_conv_func(dims): + if dims == 1: + return F.conv1d + elif dims == 2: + return F.conv2d + elif dims == 3: + return F.conv3d + return NotImplementedError + + +def get_filter_for_registration(filt, channels, dims): + if dims == 1: + return filt[None, None, :].repeat((channels, 1, 1)) + elif dims == 2: + return filt[None, None, :, :].repeat((channels, 1, 1, 1)) + elif dims == 3: + return filt[None, None, :, :, :].repeat((channels, 1, 1, 1, 1)) + return NotImplementedError + + +def get_identity_inference_func(dims): + if dims == 1: + return lambda x, s: x[:, :, ::s] + elif dims == 2: + return lambda x, s: x[:, :, ::s, ::s] + elif dims == 3: + return lambda x, s: x[:, :, ::s, ::s, ::s] + return NotImplementedError + + +BlurPool1D = partial(BlurPoolND, dims=1) +BlurPool = BlurPoolND +BlurPool3D = partial(BlurPoolND, dims=3)