diff --git a/ibl/evaluators.py b/ibl/evaluators.py index a4e2ee5..f78f1c1 100644 --- a/ibl/evaluators.py +++ b/ibl/evaluators.py @@ -22,11 +22,15 @@ def extract_cnn_feature(model, inputs, vlad=True, gpu=None): model.eval() inputs = to_torch(inputs).cuda(gpu) - x_pool, x_vlad = model(inputs) - if vlad: - outputs = F.normalize(x_vlad, p=2, dim=-1) + outputs = model(inputs) + if (isinstance(outputs, list) or isinstance(outputs, tuple)): + x_pool, x_vlad = outputs + if vlad: + outputs = F.normalize(x_vlad, p=2, dim=-1) + else: + outputs = F.normalize(x_pool, p=2, dim=-1) else: - outputs = F.normalize(x_pool, p=2, dim=-1) + outputs = F.normalize(outputs, p=2, dim=-1) return outputs def extract_features(model, data_loader, dataset, print_freq=10, diff --git a/ibl/models/__init__.py b/ibl/models/__init__.py index 7e41c43..40a0aab 100644 --- a/ibl/models/__init__.py +++ b/ibl/models/__init__.py @@ -8,6 +8,7 @@ 'vgg16': vgg16, 'netvlad': NetVLAD, 'embednet': EmbedNet, + 'embednetpca': EmbedNetPCA, 'embedregionnet': EmbedRegionNet, } diff --git a/ibl/models/netvlad.py b/ibl/models/netvlad.py index 483ec5d..9dff22c 100644 --- a/ibl/models/netvlad.py +++ b/ibl/models/netvlad.py @@ -81,6 +81,34 @@ def forward(self, x): return pool_x, vlad_x +class EmbedNetPCA(nn.Module): + def __init__(self, base_model, net_vlad, dim=4096): + super(EmbedNetPCA, self).__init__() + self.base_model = base_model + self.net_vlad = net_vlad + self.pca_layer = nn.Conv2d(net_vlad.num_clusters*net_vlad.dim, dim, 1, stride=1, padding=0) + + def _init_params(self): + self.base_model._init_params() + self.net_vlad._init_params() + + def forward(self, x): + _, x = self.base_model(x) + vlad_x = self.net_vlad(x) + + # [IMPORTANT] normalize + vlad_x = F.normalize(vlad_x, p=2, dim=2) # intra-normalization + vlad_x = vlad_x.view(x.size(0), -1) # flatten + vlad_x = F.normalize(vlad_x, p=2, dim=1) # L2 normalize + + # reduction + N, D = vlad_x.size() + vlad_x = vlad_x.view(N, D, 1, 1) + vlad_x = self.pca_layer(vlad_x).view(N, -1) + vlad_x = F.normalize(vlad_x, p=2, dim=-1) # L2 normalize + + return vlad_x + class EmbedRegionNet(nn.Module): def __init__(self, base_model, net_vlad, tuple_size=1): super(EmbedRegionNet, self).__init__() diff --git a/ibl/utils/osutils.py b/ibl/utils/osutils.py index af3417d..9cb6f93 100644 --- a/ibl/utils/osutils.py +++ b/ibl/utils/osutils.py @@ -4,6 +4,7 @@ def mkdir_if_missing(dir_path): + if not dir_path: return try: os.makedirs(dir_path) except OSError as e: