Commit a3b301e (1 parent: 3763561). Showing 26 changed files with 830 additions and 0 deletions.
semilearn/algorithms/srflexmatch/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .srflexmatch import FlexMatch
from .utils import FlexMatchThresholdingHook
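This `__init__.py` re-exports the two classes; name-based construction goes through the `ALGORITHMS` registry used in flexmatch.py below. A minimal, self-contained sketch of that decorator pattern, using a toy dict-backed `Register` rather than semilearn's actual class:

class Register(dict):
    # Toy registry: register(name) stores a class under a string key,
    # mirroring how @ALGORITHMS.register('srflexmatch') is used below.
    def register(self, name):
        def decorator(cls):
            self[name] = cls
            return cls
        return decorator

ALGORITHMS = Register()

@ALGORITHMS.register('srflexmatch')
class FlexMatch:
    pass

print(ALGORITHMS['srflexmatch'])  # <class '__main__.FlexMatch'>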
Binary file added (not shown): semilearn/algorithms/srflexmatch/__pycache__/__init__.cpython-39.pyc (+263 Bytes)
Binary file added (not shown): semilearn/algorithms/srflexmatch/__pycache__/flexmatch.cpython-39.pyc (+5.76 KB)
Two further binary files not shown.
semilearn/algorithms/srflexmatch/flexmatch.py
@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import torch
import numpy as np

from .utils import FlexMatchThresholdingHook
from semilearn.core import AlgorithmBase
from semilearn.core.utils import ALGORITHMS
from semilearn.algorithms.hooks import PseudoLabelingHook
from semilearn.algorithms.utils import SSL_Argument, str2bool
from semilearn.algorithms.SemiReward import Rewarder, Generator, SemiReward_infer, SemiReward_train

# A bare `CUDA_LAUNCH_BLOCKING = 1` assignment would only create a Python variable;
# the setting has to reach CUDA through the environment.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


@ALGORITHMS.register('srflexmatch')
class FlexMatch(AlgorithmBase):
    """
    FlexMatch algorithm (https://arxiv.org/abs/2110.08263).

    Args:
        - args (`argparse`):
            algorithm arguments
        - net_builder (`callable`):
            network loading function
        - tb_log (`TBLog`):
            tensorboard logger
        - logger (`logging.Logger`):
            logger to use
        - T (`float`):
            Temperature for pseudo-label sharpening
        - p_cutoff (`float`):
            Confidence threshold for generating pseudo-labels
        - hard_label (`bool`, *optional*, defaults to `True`):
            If True, targets have [Batch size] shape with int values; if False, the target is a soft-label vector.
        - ulb_dest_len (`int`):
            Length of the unlabeled data
        - thresh_warmup (`bool`, *optional*, defaults to `True`):
            If True, warm up the confidence threshold so that, early in training, the estimated
            learning effects rise gradually from 0 until unused unlabeled data no longer predominates.
    """
    def __init__(self, args, net_builder, tb_log=None, logger=None):
        super().__init__(args, net_builder, tb_log, logger)
        # flexmatch-specific arguments
        self.init(T=args.T, p_cutoff=args.p_cutoff, hard_label=args.hard_label, thresh_warmup=args.thresh_warmup)
        self.it = 0
        # SemiReward components: a rewarder that scores (feature, pseudo-label) pairs
        # and a generator used to train it
        self.rewarder = Rewarder(128, 384).cuda(self.gpu)
        self.generator = Generator(384).cuda(self.gpu)
        self.starttiming = 20000
        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
        self.criterion = torch.nn.MSELoss()

        self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
        self.semi_reward_train = SemiReward_train(self.rewarder, self.generator, self.criterion,
                                                  self.starttiming, self.gpu)

    def init(self, T, p_cutoff, hard_label=True, thresh_warmup=True):
        self.T = T
        self.p_cutoff = p_cutoff
        self.use_hard_label = hard_label
        self.thresh_warmup = thresh_warmup

    def set_hooks(self):
        self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
        self.register_hook(FlexMatchThresholdingHook(ulb_dest_len=self.args.ulb_dest_len,
                                                     num_classes=self.num_classes,
                                                     thresh_warmup=self.args.thresh_warmup), "MaskingHook")
        super().set_hooks()

    def train_step(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
        num_lb = y_lb.shape[0]

        # inference and calculate sup/unsup losses
        with self.amp_cm():
            if self.use_cat:
                inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
                outputs = self.model(inputs)
                logits_x_lb = outputs['logits'][:num_lb]
                logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
                feats_x_lb = outputs['feat'][:num_lb]
                feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
            else:
                outs_x_lb = self.model(x_lb)
                logits_x_lb = outs_x_lb['logits']
                feats_x_lb = outs_x_lb['feat']
                outs_x_ulb_s = self.model(x_ulb_s)
                logits_x_ulb_s = outs_x_ulb_s['logits']
                feats_x_ulb_s = outs_x_ulb_s['feat']
                with torch.no_grad():
                    outs_x_ulb_w = self.model(x_ulb_w)
                    logits_x_ulb_w = outs_x_ulb_w['logits']
                    feats_x_ulb_w = outs_x_ulb_w['feat']
            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}

            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')

            # probs_x_ulb_w = torch.softmax(logits_x_ulb_w, dim=-1)
            probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())

            if self.registered_hook("DistAlignHook"):
                probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w.detach())

            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False, idx_ulb=idx_ulb)
            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
                                          logits=probs_x_ulb_w,
                                          use_hard_label=self.use_hard_label,
                                          T=self.T,
                                          softmax=False)

            # let the rewarder filter the pseudo-labels before the consistency loss;
            # note that only the loss from the last filtered batch is kept
            filtered_pseudo_labels = self.semi_reward_infer(feats_x_ulb_w, pseudo_label, self.it)
            for filtered_pseudo_label in filtered_pseudo_labels:
                unsup_loss = self.consistency_loss(logits_x_ulb_s,
                                                   filtered_pseudo_label,
                                                   'ce',
                                                   mask=mask)

            if self.it > 0:
                generator_loss, rewarder_loss = self.semi_reward_train(feats_x_ulb_w, pseudo_label, y_lb, self.it)
                self.generator_optimizer.zero_grad()
                self.rewarder_optimizer.zero_grad()
                generator_loss.backward(retain_graph=True)
                rewarder_loss.backward(retain_graph=True)
                self.generator_optimizer.step()
                self.rewarder_optimizer.step()

            total_loss = sup_loss + self.lambda_u * unsup_loss

        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
                                         unsup_loss=unsup_loss.item(),
                                         total_loss=total_loss.item(),
                                         util_ratio=mask.float().mean().item())
        return out_dict, log_dict

    def get_save_dict(self):
        save_dict = super().get_save_dict()
        # additional saving arguments
        save_dict['classwise_acc'] = self.hooks_dict['MaskingHook'].classwise_acc.cpu()
        save_dict['selected_label'] = self.hooks_dict['MaskingHook'].selected_label.cpu()
        return save_dict

    def load_model(self, load_path):
        checkpoint = super().load_model(load_path)
        self.hooks_dict['MaskingHook'].classwise_acc = checkpoint['classwise_acc'].cuda(self.gpu)
        self.hooks_dict['MaskingHook'].selected_label = checkpoint['selected_label'].cuda(self.gpu)
        self.print_fn("additional parameter loaded")
        return checkpoint

    @staticmethod
    def get_argument():
        return [
            SSL_Argument('--hard_label', str2bool, True),
            SSL_Argument('--T', float, 0.5),
            SSL_Argument('--p_cutoff', float, 0.95),
            SSL_Argument('--thresh_warmup', str2bool, True),
        ]
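The `Rewarder(128, 384)` and `Generator(384)` constructors, and the exact filtering rule inside `SemiReward_infer`, live in `semilearn.algorithms.SemiReward`, which this commit does not show. The following is a self-contained toy sketch of the idea only: a rewarder scores each (feature, pseudo-label) pair and low-reward pseudo-labels are dropped. All shapes, the mean-reward cutoff, and `ToyRewarder` itself are illustrative assumptions, not the repository's implementation.

import torch
import torch.nn as nn

class ToyRewarder(nn.Module):
    # Hypothetical stand-in for SemiReward's Rewarder: embed the pseudo-label,
    # concatenate it with the feature, and regress a scalar reward in [0, 1].
    def __init__(self, num_classes: int, feat_dim: int, hidden: int = 128):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, feat_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2 * feat_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1), nn.Sigmoid(),
        )

    def forward(self, feats, pseudo_labels):
        joint = torch.cat([feats, self.label_emb(pseudo_labels)], dim=-1)
        return self.mlp(joint).squeeze(-1)  # one reward per unlabeled sample

feats = torch.randn(8, 384)          # stand-in for feats_x_ulb_w
pseudo = torch.randint(0, 10, (8,))  # stand-in for hard pseudo-labels
rewarder = ToyRewarder(num_classes=10, feat_dim=384)
rewards = rewarder(feats, pseudo)
keep = rewards >= rewards.mean()     # assumed cutoff, for illustration only
print(f"kept {int(keep.sum())}/8 pseudo-labels")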
semilearn/algorithms/srflexmatch/utils.py
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from copy import deepcopy
from collections import Counter

from semilearn.algorithms.hooks import MaskingHook


class FlexMatchThresholdingHook(MaskingHook):
    """
    Adaptive Thresholding in FlexMatch
    """
    def __init__(self, ulb_dest_len, num_classes, thresh_warmup=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ulb_dest_len = ulb_dest_len
        self.num_classes = num_classes
        self.thresh_warmup = thresh_warmup
        # -1 marks unlabeled samples that have not yet received a confident pseudo-label
        self.selected_label = torch.ones((self.ulb_dest_len,), dtype=torch.long) * -1
        self.classwise_acc = torch.zeros((self.num_classes,))

    @torch.no_grad()
    def update(self, *args, **kwargs):
        pseudo_counter = Counter(self.selected_label.tolist())
        if max(pseudo_counter.values()) < self.ulb_dest_len:  # not all entries are still -1
            if self.thresh_warmup:
                for i in range(self.num_classes):
                    self.classwise_acc[i] = pseudo_counter[i] / max(pseudo_counter.values())
            else:
                wo_negative_one = deepcopy(pseudo_counter)
                if -1 in wo_negative_one.keys():
                    wo_negative_one.pop(-1)
                for i in range(self.num_classes):
                    self.classwise_acc[i] = pseudo_counter[i] / max(wo_negative_one.values())

    @torch.no_grad()
    def masking(self, algorithm, logits_x_ulb, idx_ulb, softmax_x_ulb=True, *args, **kwargs):
        if not self.selected_label.is_cuda:
            self.selected_label = self.selected_label.to(logits_x_ulb.device)
        if not self.classwise_acc.is_cuda:
            self.classwise_acc = self.classwise_acc.to(logits_x_ulb.device)

        if softmax_x_ulb:
            # probs_x_ulb = torch.softmax(logits_x_ulb.detach(), dim=-1)
            probs_x_ulb = self.compute_prob(logits_x_ulb.detach())
        else:
            # logits are already probabilities
            probs_x_ulb = logits_x_ulb.detach()

        max_probs, max_idx = torch.max(probs_x_ulb, dim=-1)
        # mask = max_probs.ge(p_cutoff * (class_acc[max_idx] + 1.) / 2).float()  # linear
        # mask = max_probs.ge(p_cutoff * (1 / (2. - class_acc[max_idx]))).float()  # low_limit
        mask = max_probs.ge(algorithm.p_cutoff * (self.classwise_acc[max_idx] / (2. - self.classwise_acc[max_idx])))  # convex
        # mask = max_probs.ge(p_cutoff * (torch.log(class_acc[max_idx] + 1.) + 0.5) / (math.log(2) + 0.5)).float()  # concave
        select = max_probs.ge(algorithm.p_cutoff)
        mask = mask.to(max_probs.dtype)

        # update the per-class learning-effect statistics
        if idx_ulb[select == 1].nelement() != 0:
            self.selected_label[idx_ulb[select == 1]] = max_idx[select == 1]
        self.update()

        return mask
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
from .srfreematch import FreeMatch | ||
from .utils import FreeMatchThresholingHook |
Binary file added (not shown): semilearn/algorithms/srfreematch/__pycache__/__init__.cpython-39.pyc (+262 Bytes)
Binary file added (not shown): semilearn/algorithms/srfreematch/__pycache__/freematch.cpython-39.pyc (+6.62 KB)
Two further binary files not shown.