Add files via upload
WayneJin0918 authored Oct 3, 2023
1 parent 3763561 commit a3b301e
Showing 26 changed files with 830 additions and 0 deletions.
5 changes: 5 additions & 0 deletions semilearn/algorithms/srflexmatch/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .srflexmatch import FlexMatch
from .utils import FlexMatchThresholdingHook
4 binary files not shown.
152 changes: 152 additions & 0 deletions semilearn/algorithms/srflexmatch/srflexmatch.py
@@ -0,0 +1,152 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import torch
import numpy as np
from .utils import FlexMatchThresholdingHook
from semilearn.core import AlgorithmBase
from semilearn.core.utils import ALGORITHMS
from semilearn.algorithms.hooks import PseudoLabelingHook
from semilearn.algorithms.utils import SSL_Argument, str2bool
from semilearn.algorithms.SemiReward import Rewarder, Generator, SemiReward_infer, SemiReward_train

# force synchronous CUDA kernel launches for easier debugging
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

@ALGORITHMS.register('srflexmatch')
class FlexMatch(AlgorithmBase):
"""
FlexMatch algorithm (https://arxiv.org/abs/2110.08263).
Args:
- args (`argparse`):
algorithm arguments
- net_builder (`callable`):
network loading function
- tb_log (`TBLog`):
tensorboard logger
- logger (`logging.Logger`):
logger to use
- T (`float`):
Temperature for pseudo-label sharpening
- p_cutoff (`float`):
Confidence threshold for generating pseudo-labels
- hard_label (`bool`, *optional*, defaults to `True`):
If True, targets are hard labels of shape [Batch size] with int values. If False, the target is a soft probability vector
- ulb_dest_len (`int`):
Length of the unlabeled dataset
- thresh_warmup (`bool`, *optional*, defaults to `True`):
If True, warm up the confidence threshold so that, early in training, the estimated classwise
learning effects rise gradually from 0 until unused unlabeled data no longer predominates
"""
def __init__(self, args, net_builder, tb_log=None, logger=None):
super().__init__(args, net_builder, tb_log, logger)
# flexmatch specified arguments
self.init(T=args.T, p_cutoff=args.p_cutoff, hard_label=args.hard_label, thresh_warmup=args.thresh_warmup)
self.it = 0
# SemiReward components: a rewarder and a generator, each with its own Adam optimizer, plus an MSE criterion
self.rewarder = Rewarder(128, 384).cuda(self.gpu)
self.generator = Generator(384).cuda(self.gpu)
self.starttiming = 20000
self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
self.criterion = torch.nn.MSELoss()

self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
self.semi_reward_train = SemiReward_train(self.rewarder, self.generator, self.criterion, self.starttiming, self.gpu)

def init(self, T, p_cutoff, hard_label=True, thresh_warmup=True):
self.T = T
self.p_cutoff = p_cutoff
self.use_hard_label = hard_label
self.thresh_warmup = thresh_warmup

def set_hooks(self):
self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
self.register_hook(FlexMatchThresholdingHook(ulb_dest_len=self.args.ulb_dest_len, num_classes=self.num_classes, thresh_warmup=self.args.thresh_warmup), "MaskingHook")
super().set_hooks()

def train_step(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
num_lb = y_lb.shape[0]

# inference and calculate sup/unsup losses
with self.amp_cm():
if self.use_cat:
inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
outputs = self.model(inputs)
logits_x_lb = outputs['logits'][:num_lb]
logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
feats_x_lb = outputs['feat'][:num_lb]
feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
else:
outs_x_lb = self.model(x_lb)
logits_x_lb = outs_x_lb['logits']
feats_x_lb = outs_x_lb['feat']
outs_x_ulb_s = self.model(x_ulb_s)
logits_x_ulb_s = outs_x_ulb_s['logits']
feats_x_ulb_s = outs_x_ulb_s['feat']
with torch.no_grad():
outs_x_ulb_w = self.model(x_ulb_w)
logits_x_ulb_w = outs_x_ulb_w['logits']
feats_x_ulb_w = outs_x_ulb_w['feat']
feat_dict = {'x_lb':feats_x_lb, 'x_ulb_w':feats_x_ulb_w, 'x_ulb_s':feats_x_ulb_s}

sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')

# probs_x_ulb_w = torch.softmax(logits_x_ulb_w, dim=-1)
probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())

if self.registered_hook("DistAlignHook"):
probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w.detach())
mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False, idx_ulb=idx_ulb)
pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
logits=probs_x_ulb_w,
use_hard_label=self.use_hard_label,
T=self.T,
softmax=False)

filtered_pseudo_labels = self.semi_reward_infer(feats_x_ulb_w, pseudo_label, self.it)
for filtered_pseudo_label in filtered_pseudo_labels:
unsup_loss = self.consistency_loss(logits_x_ulb_s,
filtered_pseudo_label,
'ce',
mask=mask)

if self.it > 0:
generator_loss, rewarder_loss = self.semi_reward_train(feats_x_ulb_w, pseudo_label, y_lb, self.it)
self.generator_optimizer.zero_grad()
self.rewarder_optimizer.zero_grad()
generator_loss.backward(retain_graph=True)
rewarder_loss.backward(retain_graph=True)
self.generator_optimizer.step()
self.rewarder_optimizer.step()
total_loss = sup_loss + self.lambda_u * unsup_loss

out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
unsup_loss=unsup_loss.item(),
total_loss=total_loss.item(),
util_ratio=mask.float().mean().item())
return out_dict, log_dict


def get_save_dict(self):
save_dict = super().get_save_dict()
# additional saving arguments
save_dict['classwise_acc'] = self.hooks_dict['MaskingHook'].classwise_acc.cpu()
save_dict['selected_label'] = self.hooks_dict['MaskingHook'].selected_label.cpu()
return save_dict

def load_model(self, load_path):
checkpoint = super().load_model(load_path)
self.hooks_dict['MaskingHook'].classwise_acc = checkpoint['classwise_acc'].cuda(self.gpu)
self.hooks_dict['MaskingHook'].selected_label = checkpoint['selected_label'].cuda(self.gpu)
self.print_fn("additional parameter loaded")
return checkpoint

@staticmethod
def get_argument():
return [
SSL_Argument('--hard_label', str2bool, True),
SSL_Argument('--T', float, 0.5),
SSL_Argument('--p_cutoff', float, 0.95),
SSL_Argument('--thresh_warmup', str2bool, True),
]
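Note: train_step above updates two auxiliary SemiReward modules, a rewarder and a generator, each with its own Adam optimizer (lr=0.0005) and a shared MSE criterion, backpropagating each loss before stepping both optimizers. Below is a minimal, self-contained sketch of that two-optimizer update pattern; the stand-in MLPs, random feature tensors, and reward targets are illustrative assumptions, not the SemiReward implementation.

import torch
import torch.nn as nn

feat_dim, hidden_dim, batch = 384, 128, 8  # mirrors the 384/128 sizes used in the commit

# hypothetical stand-ins for the SemiReward Rewarder/Generator (plain MLPs, not the real modules)
rewarder = nn.Sequential(nn.Linear(feat_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))
generator = nn.Sequential(nn.Linear(feat_dim, feat_dim), nn.ReLU(), nn.Linear(feat_dim, 1))
criterion = nn.MSELoss()
rewarder_opt = torch.optim.Adam(rewarder.parameters(), lr=0.0005)
generator_opt = torch.optim.Adam(generator.parameters(), lr=0.0005)

feats = torch.randn(batch, feat_dim)   # stand-in for feats_x_ulb_w
target_reward = torch.rand(batch, 1)   # placeholder reward targets for the sketch

generator_loss = criterion(generator(feats), target_reward)
rewarder_loss = criterion(rewarder(feats), target_reward)

# same update order as in train_step: zero both optimizers, backprop each loss, step both;
# retain_graph=True is needed on the first backward when the two losses share intermediate tensors
generator_opt.zero_grad()
rewarder_opt.zero_grad()
generator_loss.backward(retain_graph=True)
rewarder_loss.backward()
generator_opt.step()
rewarder_opt.step()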
67 changes: 67 additions & 0 deletions semilearn/algorithms/srflexmatch/utils.py
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
from copy import deepcopy
from collections import Counter

from semilearn.algorithms.hooks import MaskingHook


class FlexMatchThresholdingHook(MaskingHook):
"""
Adaptive Thresholding in FlexMatch
"""
def __init__(self, ulb_dest_len, num_classes, thresh_warmup=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ulb_dest_len = ulb_dest_len
self.num_classes = num_classes
self.thresh_warmup = thresh_warmup
self.selected_label = torch.ones((self.ulb_dest_len,), dtype=torch.long) * -1
self.classwise_acc = torch.zeros((self.num_classes,))

@torch.no_grad()
def update(self, *args, **kwargs):
pseudo_counter = Counter(self.selected_label.tolist())
if max(pseudo_counter.values()) < self.ulb_dest_len:  # not all unlabeled entries (e.g., all 50k) are still -1
if self.thresh_warmup:
for i in range(self.num_classes):
self.classwise_acc[i] = pseudo_counter[i] / max(pseudo_counter.values())
else:
wo_negative_one = deepcopy(pseudo_counter)
if -1 in wo_negative_one.keys():
wo_negative_one.pop(-1)
for i in range(self.num_classes):
self.classwise_acc[i] = pseudo_counter[i] / max(wo_negative_one.values())

@torch.no_grad()
def masking(self, algorithm, logits_x_ulb, idx_ulb, softmax_x_ulb=True, *args, **kwargs):
if not self.selected_label.is_cuda:
self.selected_label = self.selected_label.to(logits_x_ulb.device)
if not self.classwise_acc.is_cuda:
self.classwise_acc = self.classwise_acc.to(logits_x_ulb.device)

if softmax_x_ulb:
# probs_x_ulb = torch.softmax(logits_x_ulb.detach(), dim=-1)
probs_x_ulb = self.compute_prob(logits_x_ulb.detach())
else:
# logits is already probs
probs_x_ulb = logits_x_ulb.detach()
max_probs, max_idx = torch.max(probs_x_ulb, dim=-1)
# mask = max_probs.ge(p_cutoff * (class_acc[max_idx] + 1.) / 2).float() # linear
# mask = max_probs.ge(p_cutoff * (1 / (2. - class_acc[max_idx]))).float() # low_limit
mask = max_probs.ge(algorithm.p_cutoff * (self.classwise_acc[max_idx] / (2. - self.classwise_acc[max_idx]))) # convex
# mask = max_probs.ge(p_cutoff * (torch.log(class_acc[max_idx] + 1.) + 0.5)/(math.log(2) + 0.5)).float() # concave
select = max_probs.ge(algorithm.p_cutoff)
mask = mask.to(max_probs.dtype)

# update
if idx_ulb[select == 1].nelement() != 0:
self.selected_label[idx_ulb[select == 1]] = max_idx[select == 1]
self.update()

return mask
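
Note: masking() above scales the fixed confidence cutoff by a convex function of each class's estimated learning status (classwise_acc), so weakly learned classes get lower effective thresholds. A minimal numeric sketch, assuming the 0.95 default cutoff and made-up classwise values:

import torch

p_cutoff = 0.95  # matches the --p_cutoff default registered in get_argument
classwise_acc = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])  # hypothetical per-class learning status

# convex mapping used in masking(): classes with low estimated learning status get a much
# lower effective threshold, so more of their pseudo-labels pass the mask early in training
effective_threshold = p_cutoff * (classwise_acc / (2.0 - classwise_acc))
print(effective_threshold)  # tensor([0.0000, 0.1357, 0.3167, 0.5700, 0.9500])

Separately, only samples whose confidence also exceeds p_cutoff itself (the select tensor above) update selected_label and, through update(), the classwise learning status.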




5 changes: 5 additions & 0 deletions semilearn/algorithms/srfreematch/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .srfreematch import FreeMatch
from .utils import FreeMatchThresholingHook
4 binary files not shown.
(diffs for the remaining changed files are not shown)
