Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix gram-matrix detector #228

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 114 additions & 104 deletions openood/postprocessors/gram_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from .base_postprocessor import BasePostprocessor
from .info import num_classes_dict

from collections import defaultdict, Counter
import random

class GRAMPostprocessor(BasePostprocessor):
def __init__(self, config):
Expand All @@ -24,17 +25,21 @@ def __init__(self, config):
self.setup_flag = False

def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
net = FeatureExtractor(net)
if not self.setup_flag:
self.feature_min, self.feature_max = sample_estimator(
self.feature_min, self.feature_max, self.normalize_factors = sample_estimator(
net, id_loader_dict['train'], self.num_classes, self.powers)
self.setup_flag = True
else:
pass
net.destroy_hooks()

def postprocess(self, net: nn.Module, data: Any):
net = FeatureExtractor(net)
preds, deviations = get_deviations(net, data, self.feature_min,
self.feature_max, self.num_classes,
self.feature_max, self.normalize_factors,
self.powers)
net.destroy_hooks()
return preds, deviations

def set_hyperparam(self, hyperparam: list):
Expand All @@ -47,121 +52,126 @@ def get_hyperparam(self):
def tensor2list(x):
return x.data.cuda().tolist()

def G_p(ob, p):
temp = ob.detach()

temp = temp**p
temp = temp.reshape(temp.shape[0],temp.shape[1],-1)
temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1)))).sum(dim=2)
temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1)

return temp

def delta(mins, maxs, x):
dev = (F.relu(mins-x)/torch.abs(mins+10**-6)).sum(dim=1,keepdim=True)
dev += (F.relu(x-maxs)/torch.abs(maxs+10**-6)).sum(dim=1,keepdim=True)
return dev

class FeatureExtractor(torch.nn.Module):
# Inspired from https://github.com/paaatcha/gram-ood
def __init__(self, torch_model):
super().__init__()
self.torch_model = torch_model
self.feat_list = list()
def _hook_fn(_, input, output):
self.feat_list.append(output)

# To set a different layer, you must use this function:
def hook_layers(torch_model):
hooked_layers = list()
for layer in torch_model.modules():
if isinstance(layer, nn.ReLU) or isinstance(layer, nn.Conv2d):
hooked_layers.append(layer)
return hooked_layers

def register_layers(layers):
regs_layers = list()
for lay in layers:
regs_layers.append(lay.register_forward_hook(_hook_fn))
return regs_layers

## Setting the hook
hl = hook_layers (torch_model)
self.rgl = register_layers (hl)
# print(f"{len(self.rgl)} Features")

def forward(self, x, return_feature_list=True):
preds = self.torch_model(x)
list = self.feat_list.copy()
self.feat_list.clear()
return preds, list

def destroy_hooks(self):
for lay in self.rgl:
lay.remove()

@torch.no_grad()
def sample_estimator(model, train_loader, num_classes, powers):

model.eval()

num_layer = 5 # 4 for lenet
num_poles_list = powers
num_poles = len(num_poles_list)
feature_class = [[[None for x in range(num_poles)]
for y in range(num_layer)] for z in range(num_classes)]
label_list = []
mins = [[[None for x in range(num_poles)] for y in range(num_layer)]
for z in range(num_classes)]
maxs = [[[None for x in range(num_poles)] for y in range(num_layer)]
for z in range(num_classes)]

gram_features = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda : None)))
mins = dict()
maxs = dict()
class_counts = Counter()
# collect features and compute gram metrix
for batch in tqdm(train_loader, desc='Compute min/max'):
data = batch['data'].cuda()
label = batch['label']
_, feature_list = model(data, return_feature_list=True)
label_list = tensor2list(label)
for layer_idx in range(num_layer):

for pole_idx, p in enumerate(num_poles_list):
temp = feature_list[layer_idx].detach()

temp = temp**p
temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
temp = ((torch.matmul(temp,
temp.transpose(dim0=2,
dim1=1)))).sum(dim=2)
temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
temp.shape[0], -1)

temp = tensor2list(temp)
for feature, label in zip(temp, label_list):
if isinstance(feature_class[label][layer_idx][pole_idx],
type(None)):
feature_class[label][layer_idx][pole_idx] = feature
class_counts.update(Counter(label.cpu().numpy()))
for layer_idx, feature in enumerate(feature_list):
for power in powers:
gram_feature = G_p(feature, power).cpu()
for class_ in range(num_classes):
if gram_features[layer_idx][power][class_] is None:
gram_features[layer_idx][power][class_] = gram_feature[label==class_]
else:
feature_class[label][layer_idx][pole_idx].extend(
feature)
gram_features[layer_idx][power][class_] = torch.cat([gram_features[layer_idx][power][class_],gram_feature[label==class_]],dim=0)

val_idxs = {}
train_idxs = {}
for c in class_counts:
L = class_counts[c]
val_idxs[c] = random.sample(range(L),int(0.1*L))
train_idxs[c] = list(set(range(L)) - set(val_idxs[c]))
normalize_factors = []
# compute mins/maxs
for label in range(num_classes):
for layer_idx in range(num_layer):
for poles_idx in range(num_poles):
feature = torch.tensor(
np.array(feature_class[label][layer_idx][poles_idx]))
current_min = feature.min(dim=0, keepdim=True)[0]
current_max = feature.max(dim=0, keepdim=True)[0]

if mins[label][layer_idx][poles_idx] is None:
mins[label][layer_idx][poles_idx] = current_min
maxs[label][layer_idx][poles_idx] = current_max
else:
mins[label][layer_idx][poles_idx] = torch.min(
current_min, mins[label][layer_idx][poles_idx])
maxs[label][layer_idx][poles_idx] = torch.max(
current_min, maxs[label][layer_idx][poles_idx])

return mins, maxs


def get_deviations(model, data, mins, maxs, num_classes, powers):
for layer_idx in gram_features:
total_delta = None
for class_ in class_counts:
trn = train_idxs[class_]
val = val_idxs[class_]
class_deltas = 0
for power in powers:
mins[layer_idx,power,class_] = gram_features[layer_idx][power][class_][trn].min(dim=0,keepdim=True)[0]
maxs[layer_idx,power,class_] = gram_features[layer_idx][power][class_][trn].max(dim=0,keepdim=True)[0]
class_deltas += delta(mins[layer_idx,power,class_],
maxs[layer_idx,power,class_],
gram_features[layer_idx][power][class_][val])
if total_delta is None:
total_delta = class_deltas
else:
total_delta = torch.cat([total_delta,class_deltas],dim=0)
normalize_factors.append(total_delta.mean(dim=0,keepdim=True))
normalize_factors = torch.cat(normalize_factors,dim=1)
return mins, maxs, normalize_factors

def get_deviations(model, data, mins, maxs, normalize_factors, powers):
model.eval()

num_layer = 5 # 4 for lenet
num_poles_list = powers
exist = 1
pred_list = []
dev = [0 for x in range(data.shape[0])]
deviations = torch.zeros(data.shape[0],1)

# get predictions
logits, feature_list = model(data, return_feature_list=True)
confs = F.softmax(logits, dim=1).cpu().detach().numpy()
preds = np.argmax(confs, axis=1)
predsList = preds.tolist()
preds = torch.tensor(preds)

for pred in predsList:
exist = 1
if len(pred_list) == 0:
pred_list.extend([pred])
else:
for pred_now in pred_list:
if pred_now == pred:
exist = 0
if exist == 1:
pred_list.extend([pred])

# compute sample level deviation
for layer_idx in range(num_layer):
for pole_idx, p in enumerate(num_poles_list):
# get gram metirx
temp = feature_list[layer_idx].detach()
temp = temp**p
temp = temp.reshape(temp.shape[0], temp.shape[1], -1)
temp = ((torch.matmul(temp, temp.transpose(dim0=2,
dim1=1)))).sum(dim=2)
temp = (temp.sign() * torch.abs(temp)**(1 / p)).reshape(
temp.shape[0], -1)
temp = tensor2list(temp)

# compute the deviations with train data
for idx in range(len(temp)):
dev[idx] += (F.relu(mins[preds[idx]][layer_idx][pole_idx] -
sum(temp[idx])) /
torch.abs(mins[preds[idx]][layer_idx][pole_idx] +
10**-6)).sum()
dev[idx] += (F.relu(
sum(temp[idx]) - maxs[preds[idx]][layer_idx][pole_idx]) /
torch.abs(maxs[preds[idx]][layer_idx][pole_idx] +
10**-6)).sum()
conf = [i / 50 for i in dev]

return preds, torch.tensor(conf)
confs = F.softmax(logits, dim=1).cpu().detach()
confs, preds = confs.max(dim=1)
for layer_idx, feature in enumerate(feature_list):
n = normalize_factors[:,layer_idx].item()
for power in powers:
gram_feature = G_p(feature, power).cpu()
for class_ in range(logits.shape[1]):
deviations[preds==class_] += delta(mins[layer_idx,power,class_],
maxs[layer_idx,power,class_],
gram_feature[preds==class_])/n

return preds, -deviations/confs[:,None]