From a3b301e078361f4cef2fd6e688081829e63a269e Mon Sep 17 00:00:00 2001
From: Weiyang Jin <137654456+WayneJin0918@users.noreply.github.com>
Date: Tue, 3 Oct 2023 17:54:24 +0800
Subject: [PATCH] Add files via upload

---
 semilearn/algorithms/srflexmatch/__init__.py        |   5 +
 semilearn/algorithms/srflexmatch/srflexmatch.py     | 156 +++++++++++++++++
 semilearn/algorithms/srflexmatch/utils.py           |  69 +++++++
 semilearn/algorithms/srfreematch/__init__.py        |   5 +
 semilearn/algorithms/srfreematch/srfreematch.py     | 174 ++++++++++++++++++
 semilearn/algorithms/srfreematch/utils.py           |  71 +++++++
 semilearn/algorithms/srpseudolabel/__init__.py      |   4 +
 semilearn/algorithms/srpseudolabel/srpseudolabel.py | 114 ++++++++++++
 semilearn/algorithms/srsoftmatch/__init__.py        |   5 +
 semilearn/algorithms/srsoftmatch/srsoftmatch.py     | 166 +++++++++++++++++
 semilearn/algorithms/srsoftmatch/utils.py           |  79 ++++++++
 11 files changed, 848 insertions(+)
 create mode 100644 semilearn/algorithms/srflexmatch/__init__.py
 create mode 100644 semilearn/algorithms/srflexmatch/srflexmatch.py
 create mode 100644 semilearn/algorithms/srflexmatch/utils.py
 create mode 100644 semilearn/algorithms/srfreematch/__init__.py
 create mode 100644 semilearn/algorithms/srfreematch/srfreematch.py
 create mode 100644 semilearn/algorithms/srfreematch/utils.py
 create mode 100644 semilearn/algorithms/srpseudolabel/__init__.py
 create mode 100644 semilearn/algorithms/srpseudolabel/srpseudolabel.py
 create mode 100644 semilearn/algorithms/srsoftmatch/__init__.py
 create mode 100644 semilearn/algorithms/srsoftmatch/srsoftmatch.py
 create mode 100644 semilearn/algorithms/srsoftmatch/utils.py
diff --git a/semilearn/algorithms/srflexmatch/__init__.py b/semilearn/algorithms/srflexmatch/__init__.py
new file mode 100644
index 0000000..ff356f2
--- /dev/null
+++ b/semilearn/algorithms/srflexmatch/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .srflexmatch import FlexMatch
+from .utils import FlexMatchThresholdingHook
\ No newline at end of file
diff --git a/semilearn/algorithms/srflexmatch/srflexmatch.py b/semilearn/algorithms/srflexmatch/srflexmatch.py
new file mode 100644
index 0000000..c5dafa0
--- /dev/null
+++ b/semilearn/algorithms/srflexmatch/srflexmatch.py
@@ -0,0 +1,156 @@
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
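+# FlexMatch extended with SemiReward: a learned rewarder scores candidate pseudo-labels
+# and filters out low-reward ones before they enter the consistency loss.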
+import os
+import torch
+import numpy as np
+from .utils import FlexMatchThresholdingHook
+from semilearn.core import AlgorithmBase
+from semilearn.core.utils import ALGORITHMS
+from semilearn.algorithms.hooks import PseudoLabelingHook
+from semilearn.algorithms.utils import SSL_Argument, str2bool
+from semilearn.algorithms.SemiReward import Rewarder, Generator, SemiReward_infer, SemiReward_train
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # report CUDA errors synchronously
+@ALGORITHMS.register('srflexmatch')
+class FlexMatch(AlgorithmBase):
+    """
+        FlexMatch algorithm (https://arxiv.org/abs/2110.08263).
+
+        Args:
+            - args (`argparse`):
+                algorithm arguments
+            - net_builder (`callable`):
+                network loading function
+            - tb_log (`TBLog`):
+                tensorboard logger
+            - logger (`logging.Logger`):
+                logger to use
+            - T (`float`):
+                Temperature for pseudo-label sharpening
+            - p_cutoff (`float`):
+                Confidence threshold for generating pseudo-labels
+            - hard_label (`bool`, *optional*, defaults to `False`):
+                If True, targets have [Batch size] shape with int values. If False, the target is a probability vector
+            - ulb_dest_len (`int`):
+                Length of the unlabeled data
+            - thresh_warmup (`bool`, *optional*, defaults to `True`):
+                If True, warm up the confidence threshold so that, at the beginning of training, all estimated
+                learning effects rise gradually from 0 until the number of unused unlabeled samples is no longer
+                predominant
+
+    """
+    def __init__(self, args, net_builder, tb_log=None, logger=None):
+        super().__init__(args, net_builder, tb_log, logger)
+        # flexmatch specified arguments
+        self.init(T=args.T, p_cutoff=args.p_cutoff, hard_label=args.hard_label, thresh_warmup=args.thresh_warmup)
+        self.it = 0
+        self.rewarder = Rewarder(128, 384).cuda(self.gpu)
+        self.generator = Generator(384).cuda(self.gpu)
+        self.starttiming = 20000
+        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
+        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
+        self.criterion = torch.nn.MSELoss()
+
+        self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
+        self.semi_reward_train = SemiReward_train(self.rewarder, self.generator, self.criterion, self.starttiming, self.gpu)
+    def init(self, T, p_cutoff, hard_label=True, thresh_warmup=True):
+        self.T = T
+        self.p_cutoff = p_cutoff
+        self.use_hard_label = hard_label
+        self.thresh_warmup = thresh_warmup
+
+    def set_hooks(self):
+        self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
+        self.register_hook(FlexMatchThresholdingHook(ulb_dest_len=self.args.ulb_dest_len, num_classes=self.num_classes, thresh_warmup=self.args.thresh_warmup), "MaskingHook")
+        super().set_hooks()
+
+    def train_step(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
+        num_lb = y_lb.shape[0]
+
+        # inference and calculate sup/unsup losses
+        with self.amp_cm():
+            if self.use_cat:
+                inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
+                outputs = self.model(inputs)
+                logits_x_lb = outputs['logits'][:num_lb]
+                logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
+                feats_x_lb = outputs['feat'][:num_lb]
+                feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
+            else:
+                outs_x_lb = self.model(x_lb)
+                logits_x_lb = outs_x_lb['logits']
+                feats_x_lb = outs_x_lb['feat']
+                outs_x_ulb_s = self.model(x_ulb_s)
+                logits_x_ulb_s = outs_x_ulb_s['logits']
+                feats_x_ulb_s = outs_x_ulb_s['feat']
+                with torch.no_grad():
+                    outs_x_ulb_w = self.model(x_ulb_w)
+                    logits_x_ulb_w = outs_x_ulb_w['logits']
+                    feats_x_ulb_w = outs_x_ulb_w['feat']
+            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}
+
+            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')
+
+            # probs_x_ulb_w = torch.softmax(logits_x_ulb_w, dim=-1)
+            probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())
+
+            if self.registered_hook("DistAlignHook"):
+                probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w.detach())
+            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False, idx_ulb=idx_ulb)
+            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
+                                          logits=probs_x_ulb_w,
+                                          use_hard_label=self.use_hard_label,
+                                          T=self.T,
+                                          softmax=False)
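+            # SemiReward inference: score (weak-view feature, pseudo-label) pairs with the
+            # rewarder and keep only the labels it rates highly for the consistency loss below.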
+            filtered_pseudo_labels = self.semi_reward_infer(feats_x_ulb_w, pseudo_label, self.it)
+            for filtered_pseudo_label in filtered_pseudo_labels:
+                unsup_loss = self.consistency_loss(logits_x_ulb_s,
+                                                   filtered_pseudo_label,
+                                                   'ce',
+                                                   mask=mask)
+
+            if self.it > 0:
+                generator_loss, rewarder_loss = self.semi_reward_train(feats_x_ulb_w, pseudo_label, y_lb, self.it)
+                self.generator_optimizer.zero_grad()
+                self.rewarder_optimizer.zero_grad()
+                generator_loss.backward(retain_graph=True)
+                rewarder_loss.backward(retain_graph=True)
+                self.generator_optimizer.step()
+                self.rewarder_optimizer.step()
+            total_loss = sup_loss + self.lambda_u * unsup_loss
+
+        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
+        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
+                                         unsup_loss=unsup_loss.item(),
+                                         total_loss=total_loss.item(),
+                                         util_ratio=mask.float().mean().item())
+        return out_dict, log_dict
+
+
+    def get_save_dict(self):
+        save_dict = super().get_save_dict()
+        # additional saving arguments
+        save_dict['classwise_acc'] = self.hooks_dict['MaskingHook'].classwise_acc.cpu()
+        save_dict['selected_label'] = self.hooks_dict['MaskingHook'].selected_label.cpu()
+        return save_dict
+
+    def load_model(self, load_path):
+        checkpoint = super().load_model(load_path)
+        self.hooks_dict['MaskingHook'].classwise_acc = checkpoint['classwise_acc'].cuda(self.gpu)
+        self.hooks_dict['MaskingHook'].selected_label = checkpoint['selected_label'].cuda(self.gpu)
+        self.print_fn("additional parameter loaded")
+        return checkpoint
+
+    @staticmethod
+    def get_argument():
+        return [
+            SSL_Argument('--hard_label', str2bool, True),
+            SSL_Argument('--T', float, 0.5),
+            SSL_Argument('--p_cutoff', float, 0.95),
+            SSL_Argument('--thresh_warmup', str2bool, True),
+        ]
diff --git a/semilearn/algorithms/srflexmatch/utils.py b/semilearn/algorithms/srflexmatch/utils.py
new file mode 100644
index 0000000..2c32d14
--- /dev/null
+++ b/semilearn/algorithms/srflexmatch/utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+from copy import deepcopy
+from collections import Counter
+
+from semilearn.algorithms.hooks import MaskingHook
+
+
+class FlexMatchThresholdingHook(MaskingHook):
+    """
+    Adaptive Thresholding in FlexMatch
+    """
+    def __init__(self, ulb_dest_len, num_classes, thresh_warmup=True, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ulb_dest_len = ulb_dest_len
+        self.num_classes = num_classes
+        self.thresh_warmup = thresh_warmup
+        self.selected_label = torch.ones((self.ulb_dest_len,), dtype=torch.long) * -1
+        self.classwise_acc = torch.zeros((self.num_classes,))
+
+    @torch.no_grad()
+    def update(self, *args, **kwargs):
+        pseudo_counter = Counter(self.selected_label.tolist())
+        if max(pseudo_counter.values()) < self.ulb_dest_len:  # not all(5w) -1
+            if self.thresh_warmup:
+                for i in range(self.num_classes):
+                    self.classwise_acc[i] = pseudo_counter[i] / max(pseudo_counter.values())
+            else:
+                wo_negative_one = deepcopy(pseudo_counter)
+                if -1 in wo_negative_one.keys():
+                    wo_negative_one.pop(-1)
+                for i in range(self.num_classes):
+                    self.classwise_acc[i] = pseudo_counter[i] / max(wo_negative_one.values())
+
+    @torch.no_grad()
+    def masking(self, algorithm, logits_x_ulb, idx_ulb, softmax_x_ulb=True, *args, **kwargs):
+        if not self.selected_label.is_cuda:
+            self.selected_label = self.selected_label.to(logits_x_ulb.device)
+        if not self.classwise_acc.is_cuda:
+            self.classwise_acc = self.classwise_acc.to(logits_x_ulb.device)
+
+        if softmax_x_ulb:
+            # probs_x_ulb = torch.softmax(logits_x_ulb.detach(), dim=-1)
+            probs_x_ulb = self.compute_prob(logits_x_ulb.detach())
+        else:
+            # logits is already probs
+            probs_x_ulb = logits_x_ulb.detach()
+        max_probs, max_idx = torch.max(probs_x_ulb, dim=-1)
+        # mask = max_probs.ge(p_cutoff * (class_acc[max_idx] + 1.) / 2).float()  # linear
+        # mask = max_probs.ge(p_cutoff * (1 / (2. - class_acc[max_idx]))).float()  # low_limit
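+        # curriculum threshold: scale p_cutoff by beta(c) = acc_c / (2 - acc_c), a convex
+        # function of the estimated class-wise learning effect, so harder classes get lower cutoffs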
+        mask = max_probs.ge(algorithm.p_cutoff * (self.classwise_acc[max_idx] / (2. - self.classwise_acc[max_idx])))  # convex
+        # mask = max_probs.ge(p_cutoff * (torch.log(class_acc[max_idx] + 1.) + 0.5)/(math.log(2) + 0.5)).float()  # concave
+        select = max_probs.ge(algorithm.p_cutoff)
+        mask = mask.to(max_probs.dtype)
+
+        # update
+        if idx_ulb[select == 1].nelement() != 0:
+            self.selected_label[idx_ulb[select == 1]] = max_idx[select == 1]
+        self.update()
+
+        return mask
+
+
+
diff --git a/semilearn/algorithms/srfreematch/__init__.py b/semilearn/algorithms/srfreematch/__init__.py
new file mode 100644
index 0000000..775a5f5
--- /dev/null
+++ b/semilearn/algorithms/srfreematch/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .srfreematch import FreeMatch
+from .utils import FreeMatchThresholingHook
\ No newline at end of file
diff --git a/semilearn/algorithms/srfreematch/srfreematch.py b/semilearn/algorithms/srfreematch/srfreematch.py
new file mode 100644
index 0000000..0f6d6ad
--- /dev/null
+++ b/semilearn/algorithms/srfreematch/srfreematch.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from .utils import FreeMatchThresholingHook
+from semilearn.core import AlgorithmBase
+from semilearn.core.utils import ALGORITHMS
+from semilearn.algorithms.hooks import PseudoLabelingHook
+from semilearn.algorithms.utils import SSL_Argument, str2bool
+from semilearn.algorithms.SemiReward import Rewarder, Generator, SemiReward_infer, SemiReward_train
+
+# TODO: move these to .utils or algorithms.utils.loss
+def replace_inf_to_zero(val):
+    val[val == float('inf')] = 0.0
+    return val
+
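+# self-adaptive fairness loss from FreeMatch: cross-entropy between the histogram-modulated
+# model distribution and the histogram-modulated mean prediction over masked samples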
zw8D+4)#7X8QtXJQx$ZA_0*W-%HyAV!z_WY+w7dh>hz?o;J_0&fUsm$`3apkCv*Q5m zVvt{fty}>zhihOi1DUgifndX059%dLWWA7%|2cY*7AAP*Z4rOq=#4RpV<=WE~{H@?vW(AS#) zJV6t{#kKaOX<<}bkhJMUz6U3h*O9%C3{$2ZY`TZqO5#fUb)lf4i4&n+RnDqBekwo2 zL6)j>O>b#iimaAXocs~8kJ0I>^dK)&EaL>W!tkef}+$2vPY3j<)XJ*^}6)!!0xLajM{6ZC5fb6l^3{nmxFs2WL8bcqH~H ziKy27q|7p5zVJ%+&@x$#0$o^SO82_ASJZiwqa~fgF37brehy4Rkt7BdO literal 0 HcmV?d00001 diff --git a/semilearn/algorithms/srfreematch/__pycache__/utils.cpython-39.pyc b/semilearn/algorithms/srfreematch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccdfdc4822e2c0217f468db53b3c1f2b1ba606dd GIT binary patch literal 2155 zcmZ`)&2L;a5V!rl-|nVux@yz35<#jeR;AeA5Gc zlWg^=Nq1mO@9nTk^vQDzyuvwB78}9 zsj!576bx+X%-lJgD%>M-NC)0>-V(kDjwnd{m&7K60B+MZYo917A4g?8+8!yXM@70k z{a`)ZfcEsm?TsYgIMZr`zutX$?dKo=`29g_H5XbC2N+)Q5_J6;kdhpdF|DnH)b{uE zYf4E?YrMZhT6D$&yl~PCg6)hIDbX{eArKu>L7_nG9UKG&tLT8XFCyCNpg8F zi;}#rWtOBeQh7g0hlNVYQKtL5;PDJS-ml6e)ti$S%7a5!&~0rXgsxJTx~xmP^x(}i z&tHho=IP7-6$2NxWB&-cUZ(^+J|=1nC>3Z{%yFv}S&731`l_3U6*vn=j7(; zEpWYEQ;sLD0>`-v+--2w+ByLTk(0lmZTFb5M3NbAPs#}lhTnaQ?c~7_RGK#2ZQ5@S z6z+OO)jBXp*&s{u6aY|;dR@~JfGm~lRHYOqd|s$jaGe}T#XwnMaWPgF?VEO`CI705 zazKDom@z8AMgL5Oa4@YnO(wi-GSk>NQHT)Z?<9F#RC#GUOxkE7At4&Z%ECCJd;!VP zK}}0nSls>VLB|5}1iJnK2!slXhh9TI z1JI5wEEF8&HpuTIKPR_g?tuIe@^f;hw$zuUJN9aZMMO}+j@i`B{l$C=fSd(TmfSg1 z@7f)QjRu8;@TNZ4BCt0^zSt?O8vw6qPGs9)9Wlk=OvoQm59_^uwq;xpC$1kRtp$381ZzmTOv+iy^P|Kx{qcya0eU0M*)Lg zs|l*Qi^6tV43qLSo}K0HMR?UJXsD}5-bQi{$tDm$Vj-Xk!jE$?&;Wzx8$z{ZqjjX&Lb$f+=Lvfa+Tt^%`qL)<%-9- zC}Pf=!njBqFE98|MPlKKck%w6Wi5^+e4+f@8@jNn5Tyzsh8}GSt50ofC2(nY&85u> zBu3-z^|Ly?d6pu2v#HKa;+oehjNp^j3ctX{+PMofdDEnf|K8$>@8eB0mS%7qZCif> D$d4x^ literal 0 HcmV?d00001 diff --git a/semilearn/algorithms/srfreematch/srfreematch.py b/semilearn/algorithms/srfreematch/srfreematch.py new file mode 100644 index 0000000..0f6d6ad --- /dev/null +++ b/semilearn/algorithms/srfreematch/srfreematch.py @@ -0,0 +1,170 @@ +import torch +import torch.nn.functional as F +import numpy as np +from .utils import FreeMatchThresholingHook +from semilearn.core import AlgorithmBase +from semilearn.core.utils import ALGORITHMS +from semilearn.algorithms.hooks import PseudoLabelingHook +from semilearn.algorithms.utils import SSL_Argument, str2bool +from semilearn.algorithms.SemiReward import Rewarder,Generator,SemiReward_infer,SemiReward_train + +# TODO: move these to .utils or algorithms.utils.loss +def replace_inf_to_zero(val): + val[val == float('inf')] = 0.0 + return val + +def entropy_loss(mask, logits_s, prob_model, label_hist): + mask = mask.bool() + + # select samples + logits_s = logits_s[mask] + + prob_s = logits_s.softmax(dim=-1) + _, pred_label_s = torch.max(prob_s, dim=-1) + + hist_s = torch.bincount(pred_label_s, minlength=logits_s.shape[1]).to(logits_s.dtype) + hist_s = hist_s / hist_s.sum() + + # modulate prob model + prob_model = prob_model.reshape(1, -1) + label_hist = label_hist.reshape(1, -1) + # prob_model_scaler = torch.nan_to_num(1 / label_hist, nan=0.0, posinf=0.0, neginf=0.0).detach() + prob_model_scaler = replace_inf_to_zero(1 / label_hist).detach() + mod_prob_model = prob_model * prob_model_scaler + mod_prob_model = mod_prob_model / mod_prob_model.sum(dim=-1, keepdim=True) + + # modulate mean prob + mean_prob_scaler_s = replace_inf_to_zero(1 / hist_s).detach() + # mean_prob_scaler_s = torch.nan_to_num(1 / hist_s, nan=0.0, posinf=0.0, neginf=0.0).detach() + mod_mean_prob_s = prob_s.mean(dim=0, keepdim=True) * mean_prob_scaler_s + mod_mean_prob_s = mod_mean_prob_s / 
+            if self.it > 0:
+                generator_loss, rewarder_loss = self.semi_reward_train(feats_x_ulb_w, pseudo_label, y_lb, self.it)
+                self.generator_optimizer.zero_grad()
+                self.rewarder_optimizer.zero_grad()
+                generator_loss.backward(retain_graph=True)
+                rewarder_loss.backward(retain_graph=True)
+                self.generator_optimizer.step()
+                self.rewarder_optimizer.step()
+            # calculate entropy loss
+            if mask.sum() > 0:
+                ent_loss, _ = entropy_loss(mask, logits_x_ulb_s, self.p_model, self.label_hist)
+            else:
+                ent_loss = 0.0
+            # ent_loss = 0.0
+            total_loss = sup_loss + self.lambda_u * unsup_loss + self.lambda_e * ent_loss
+
+        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
+        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
+                                         unsup_loss=unsup_loss.item(),
+                                         total_loss=total_loss.item(),
+                                         util_ratio=mask.float().mean().item())
+        return out_dict, log_dict
+
+    def get_save_dict(self):
+        save_dict = super().get_save_dict()
+        # additional saving arguments
+        save_dict['p_model'] = self.hooks_dict['MaskingHook'].p_model.cpu()
+        save_dict['time_p'] = self.hooks_dict['MaskingHook'].time_p.cpu()
+        save_dict['label_hist'] = self.hooks_dict['MaskingHook'].label_hist.cpu()
+        return save_dict
+
+
+    def load_model(self, load_path):
+        checkpoint = super().load_model(load_path)
+        self.hooks_dict['MaskingHook'].p_model = checkpoint['p_model'].cuda(self.args.gpu)
+        self.hooks_dict['MaskingHook'].time_p = checkpoint['time_p'].cuda(self.args.gpu)
+        self.hooks_dict['MaskingHook'].label_hist = checkpoint['label_hist'].cuda(self.args.gpu)
+        self.print_fn("additional parameter loaded")
+        return checkpoint
+
+    @staticmethod
+    def get_argument():
+        return [
+            SSL_Argument('--hard_label', str2bool, True),
+            SSL_Argument('--T', float, 0.5),
+            SSL_Argument('--ema_p', float, 0.999),
+            SSL_Argument('--ent_loss_ratio', float, 0.01),
+            SSL_Argument('--use_quantile', str2bool, False),
+            SSL_Argument('--clip_thresh', str2bool, False),
+        ]
\ No newline at end of file
diff --git a/semilearn/algorithms/srfreematch/utils.py b/semilearn/algorithms/srfreematch/utils.py
new file mode 100644
index 0000000..6442bfd
--- /dev/null
+++ b/semilearn/algorithms/srfreematch/utils.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+
+from semilearn.algorithms.utils import concat_all_gather
+from semilearn.algorithms.hooks import MaskingHook
+
+
+class FreeMatchThresholingHook(MaskingHook):
+    """
+    Self-Adaptive Thresholding (SAT) in FreeMatch
+    """
+    def __init__(self, num_classes, momentum=0.999, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.num_classes = num_classes
+        self.m = momentum
+
+        self.p_model = torch.ones((self.num_classes)) / self.num_classes
+        self.label_hist = torch.ones((self.num_classes)) / self.num_classes
+        self.time_p = self.p_model.mean()
+
+    @torch.no_grad()
+    def update(self, algorithm, probs_x_ulb):
+        if algorithm.distributed and algorithm.world_size > 1:
+            probs_x_ulb = concat_all_gather(probs_x_ulb)
+        max_probs, max_idx = torch.max(probs_x_ulb, dim=-1, keepdim=True)
+
+        if algorithm.use_quantile:
+            self.time_p = self.time_p * self.m + (1 - self.m) * torch.quantile(max_probs, 0.8)  # * max_probs.mean()
+        else:
+            self.time_p = self.time_p * self.m + (1 - self.m) * max_probs.mean()
+
+        if algorithm.clip_thresh:
+            self.time_p = torch.clip(self.time_p, 0.0, 0.95)
+
+        self.p_model = self.p_model * self.m + (1 - self.m) * probs_x_ulb.mean(dim=0)
+        hist = torch.bincount(max_idx.reshape(-1), minlength=self.p_model.shape[0]).to(self.p_model.dtype)
+        self.label_hist = self.label_hist * self.m + (1 - self.m) * (hist / hist.sum())
+
+        algorithm.p_model = self.p_model
+        algorithm.label_hist = self.label_hist
+        algorithm.time_p = self.time_p
+
+
+    @torch.no_grad()
+    def masking(self, algorithm, logits_x_ulb, softmax_x_ulb=True, *args, **kwargs):
+        if not self.p_model.is_cuda:
+            self.p_model = self.p_model.to(logits_x_ulb.device)
+        if not self.label_hist.is_cuda:
+            self.label_hist = self.label_hist.to(logits_x_ulb.device)
+        if not self.time_p.is_cuda:
+            self.time_p = self.time_p.to(logits_x_ulb.device)
+
+        if softmax_x_ulb:
+            probs_x_ulb = torch.softmax(logits_x_ulb.detach(), dim=-1)
+        else:
+            # logits is already probs
+            probs_x_ulb = logits_x_ulb.detach()
+
+        self.update(algorithm, probs_x_ulb)
+
+        max_probs, max_idx = probs_x_ulb.max(dim=-1)
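+        # self-adaptive threshold: global EMA confidence time_p, scaled per class by the
+        # normalized EMA class distribution p_model / max(p_model)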
+        mod = self.p_model / torch.max(self.p_model, dim=-1)[0]
+        mask = max_probs.ge(self.time_p * mod[max_idx]).to(max_probs.dtype)
+        return mask
+
+
diff --git a/semilearn/algorithms/srpseudolabel/__init__.py b/semilearn/algorithms/srpseudolabel/__init__.py
new file mode 100644
index 0000000..24c70b8
--- /dev/null
+++ b/semilearn/algorithms/srpseudolabel/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .srpseudolabel import PseudoLabel
\ No newline at end of file
diff --git a/semilearn/algorithms/srpseudolabel/srpseudolabel.py b/semilearn/algorithms/srpseudolabel/srpseudolabel.py
new file mode 100644
index 0000000..41f15b1
--- /dev/null
+++ b/semilearn/algorithms/srpseudolabel/srpseudolabel.py
@@ -0,0 +1,114 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+from semilearn.core import AlgorithmBase
+from semilearn.core.utils import ALGORITHMS
+from semilearn.algorithms.hooks import PseudoLabelingHook, FixedThresholdingHook
+from semilearn.algorithms.utils import SSL_Argument
+from semilearn.algorithms.SemiReward import Rewarder, Generator, SemiReward_infer, SemiReward_train
+import torch
+import torch.nn as nn
+
+@ALGORITHMS.register('srpseudolabel')
+class PseudoLabel(AlgorithmBase):
+    """
+        Pseudo Label algorithm (https://arxiv.org/abs/1908.02983).
+
+        Args:
+            - args (`argparse`):
+                algorithm arguments
+            - net_builder (`callable`):
+                network loading function
+            - tb_log (`TBLog`):
+                tensorboard logger
+            - logger (`logging.Logger`):
+                logger to use
+            - p_cutoff (`float`):
+                Confidence threshold for generating pseudo-labels
+            - unsup_warm_up (`float`, *optional*, defaults to 0.4):
+                Ramp-up period for the unsupervised loss weight
+    """
+
+    def __init__(self, args, net_builder, tb_log=None, logger=None, **kwargs):
+        super().__init__(args, net_builder, tb_log, logger, **kwargs)
+        self.init(p_cutoff=args.p_cutoff, unsup_warm_up=args.unsup_warm_up)
+        self.it = 0
+        self.rewarder = Rewarder(128, 384).cuda(self.gpu)
+        self.generator = Generator(384).cuda(self.gpu)
+        self.starttiming = 20000
+        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
+        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
+        self.criterion = torch.nn.MSELoss()
+
+        self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
+        self.semi_reward_train = SemiReward_train(self.rewarder, self.generator, self.criterion, self.starttiming, self.gpu)
+    def init(self, p_cutoff, unsup_warm_up=0.4):
+        self.p_cutoff = p_cutoff
+        self.unsup_warm_up = unsup_warm_up
+
+    def set_hooks(self):
+        self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
+        self.register_hook(FixedThresholdingHook(), "MaskingHook")
+        super().set_hooks()
+
+    def train_step(self, x_lb, y_lb, x_ulb_w):
+        # inference and calculate sup/unsup losses
+        with self.amp_cm():
+
+            outs_x_lb = self.model(x_lb)
+            logits_x_lb = outs_x_lb['logits']
+            feats_x_lb = outs_x_lb['feat']
+
+            # freeze BN statistics while forwarding the unlabeled batch
+            self.bn_controller.freeze_bn(self.model)
+            outs_x_ulb = self.model(x_ulb_w)
+            logits_x_ulb = outs_x_ulb['logits']
+            feats_x_ulb = outs_x_ulb['feat']
+            self.bn_controller.unfreeze_bn(self.model)
+
+            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb}
+
+            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')
+
+            # compute mask
+            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=logits_x_ulb)
+
+            # generate unlabeled targets using pseudo label hook
+            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
+                                          logits=logits_x_ulb,
+                                          use_hard_label=True)
+            filtered_pseudo_labels = self.semi_reward_infer(feats_x_ulb, pseudo_label, self.it)
+            for filtered_pseudo_label in filtered_pseudo_labels:
+                unsup_loss = self.consistency_loss(logits_x_ulb,
+                                                   filtered_pseudo_label,
+                                                   'ce',
+                                                   mask=mask)
+
+            if self.it > 0:
+                generator_loss, rewarder_loss = self.semi_reward_train(feats_x_ulb, pseudo_label, y_lb, self.it)
+                self.generator_optimizer.zero_grad()
+                self.rewarder_optimizer.zero_grad()
+                generator_loss.backward(retain_graph=True)
+                rewarder_loss.backward(retain_graph=True)
+                self.generator_optimizer.step()
+                self.rewarder_optimizer.step()
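+            # ramp the unsupervised loss weight linearly from 0 to 1 over the first
+            # unsup_warm_up fraction of training iterations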
+            unsup_warmup = np.clip(self.it / (self.unsup_warm_up * self.num_train_iter), a_min=0.0, a_max=1.0)
+            total_loss = sup_loss + self.lambda_u * unsup_loss * unsup_warmup
+
+        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
+        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
+                                         unsup_loss=unsup_loss.item(),
+                                         total_loss=total_loss.item(),
+                                         util_ratio=mask.float().mean().item())
+        return out_dict, log_dict
+
+    @staticmethod
+    def get_argument():
+        return [
+            SSL_Argument('--p_cutoff', float, 0.95),
+            SSL_Argument('--unsup_warm_up', float, 0.4, 'warm up ratio for unsupervised loss'),
+            # SSL_Argument('--use_flex', str2bool, False),
+        ]
\ No newline at end of file
diff --git a/semilearn/algorithms/srsoftmatch/__init__.py b/semilearn/algorithms/srsoftmatch/__init__.py
new file mode 100644
index 0000000..638df27
--- /dev/null
+++ b/semilearn/algorithms/srsoftmatch/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .srsoftmatch import SoftMatch
+from .utils import SoftMatchWeightingHook
\ No newline at end of file
diff --git a/semilearn/algorithms/srsoftmatch/srsoftmatch.py b/semilearn/algorithms/srsoftmatch/srsoftmatch.py
new file mode 100644
index 0000000..bf9ea76
--- /dev/null
+++ b/semilearn/algorithms/srsoftmatch/srsoftmatch.py
@@ -0,0 +1,166 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+import numpy as np
+from .utils import SoftMatchWeightingHook
+from semilearn.core.algorithmbase import AlgorithmBase
+from semilearn.core.utils import ALGORITHMS
+from semilearn.algorithms.hooks import PseudoLabelingHook, DistAlignEMAHook
+from semilearn.algorithms.utils import SSL_Argument, str2bool
+from semilearn.algorithms.SemiReward import Rewarder, Generator, SemiReward_infer, SemiReward_train
+
+@ALGORITHMS.register('srsoftmatch')
+class SoftMatch(AlgorithmBase):
+    """
+        SoftMatch algorithm (https://openreview.net/forum?id=ymt1zQXBDiF).
+
+        Args:
+            - args (`argparse`):
+                algorithm arguments
+            - net_builder (`callable`):
+                network loading function
+            - tb_log (`TBLog`):
+                tensorboard logger
+            - logger (`logging.Logger`):
+                logger to use
+            - T (`float`):
+                Temperature for pseudo-label sharpening
+            - hard_label (`bool`, *optional*, defaults to `False`):
+                If True, targets have [Batch size] shape with int values. If False, the target is a probability vector
+            - ema_p (`float`):
+                momentum for the exponential moving average of probabilities
+    """
+    def __init__(self, args, net_builder, tb_log=None, logger=None):
+        super().__init__(args, net_builder, tb_log, logger)
+        self.init(T=args.T, hard_label=args.hard_label, dist_align=args.dist_align, dist_uniform=args.dist_uniform, ema_p=args.ema_p, n_sigma=args.n_sigma, per_class=args.per_class)
+        self.it = 0
+        self.rewarder = Rewarder(128, 384).cuda(self.gpu)
+        self.generator = Generator(384).cuda(self.gpu)
+        self.starttiming = 20000
+        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
+        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
+        self.criterion = torch.nn.MSELoss()
+
+        self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
+        self.semi_reward_train = SemiReward_train(self.rewarder, self.generator, self.criterion, self.starttiming, self.gpu)
+    def init(self, T, hard_label=True, dist_align=True, dist_uniform=True, ema_p=0.999, n_sigma=2, per_class=False):
+        self.T = T
+        self.use_hard_label = hard_label
+        self.dist_align = dist_align
+        self.dist_uniform = dist_uniform
+        self.ema_p = ema_p
+        self.n_sigma = n_sigma
+        self.per_class = per_class
+
+    def set_hooks(self):
+        self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
+        self.register_hook(
+            DistAlignEMAHook(num_classes=self.num_classes, momentum=self.args.ema_p, p_target_type='uniform' if self.args.dist_uniform else 'model'),
+            "DistAlignHook")
+        self.register_hook(SoftMatchWeightingHook(num_classes=self.num_classes, n_sigma=self.args.n_sigma, momentum=self.args.ema_p, per_class=self.args.per_class), "MaskingHook")
+        super().set_hooks()
+
+    def train_step(self, x_lb, y_lb, x_ulb_w, x_ulb_s):
+        num_lb = y_lb.shape[0]
+
+        # inference and calculate sup/unsup losses
+        with self.amp_cm():
+            if self.use_cat:
+                inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
+                outputs = self.model(inputs)
+                logits_x_lb = outputs['logits'][:num_lb]
+                logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
+                feats_x_lb = outputs['feat'][:num_lb]
+                feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
+            else:
+                outs_x_lb = self.model(x_lb)
+                logits_x_lb = outs_x_lb['logits']
+                feats_x_lb = outs_x_lb['feat']
+                outs_x_ulb_s = self.model(x_ulb_s)
+                logits_x_ulb_s = outs_x_ulb_s['logits']
+                feats_x_ulb_s = outs_x_ulb_s['feat']
+                with torch.no_grad():
+                    outs_x_ulb_w = self.model(x_ulb_w)
+                    logits_x_ulb_w = outs_x_ulb_w['logits']
+                    feats_x_ulb_w = outs_x_ulb_w['feat']
+            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}
+
+
+            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')
+
+            probs_x_lb = torch.softmax(logits_x_lb.detach(), dim=-1)
+            probs_x_ulb_w = torch.softmax(logits_x_ulb_w.detach(), dim=-1)
+
+            # uniform distribution alignment
+            probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w, probs_x_lb=probs_x_lb)
+
+            # calculate weight
+            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False)
+
+            # generate unlabeled targets using pseudo label hook
+            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
+                                          # make sure this is logits, not dist aligned probs
+                                          # uniform alignment in SoftMatch does not use the aligned probs for generating pseudo-labels
+                                          logits=logits_x_ulb_w,
+                                          use_hard_label=self.use_hard_label,
+                                          T=self.T)
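+            # SemiReward filtering: keep only pseudo-labels the rewarder scores highly; the
+            # truncated-Gaussian weights from the MaskingHook still weight the consistency loss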
+            filtered_pseudo_labels = self.semi_reward_infer(feats_x_ulb_w, pseudo_label, self.it)
+            for filtered_pseudo_label in filtered_pseudo_labels:
+                unsup_loss = self.consistency_loss(logits_x_ulb_s,
+                                                   filtered_pseudo_label,
+                                                   'ce',
+                                                   mask=mask)
+
+            if self.it > 0:
+                generator_loss, rewarder_loss = self.semi_reward_train(feats_x_ulb_w, pseudo_label, y_lb, self.it)
+                self.generator_optimizer.zero_grad()
+                self.rewarder_optimizer.zero_grad()
+                generator_loss.backward(retain_graph=True)
+                rewarder_loss.backward(retain_graph=True)
+                self.generator_optimizer.step()
+                self.rewarder_optimizer.step()
+
+            total_loss = sup_loss + self.lambda_u * unsup_loss
+
+
+        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
+        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
+                                         unsup_loss=unsup_loss.item(),
+                                         total_loss=total_loss.item(),
+                                         util_ratio=mask.float().mean().item())
+        return out_dict, log_dict
+
+    # TODO: change these
+    def get_save_dict(self):
+        save_dict = super().get_save_dict()
+        # additional saving arguments
+        save_dict['p_model'] = self.hooks_dict['DistAlignHook'].p_model.cpu()
+        save_dict['p_target'] = self.hooks_dict['DistAlignHook'].p_target.cpu()
+        save_dict['prob_max_mu_t'] = self.hooks_dict['MaskingHook'].prob_max_mu_t.cpu()
+        save_dict['prob_max_var_t'] = self.hooks_dict['MaskingHook'].prob_max_var_t.cpu()
+        return save_dict
+
+
+    def load_model(self, load_path):
+        checkpoint = super().load_model(load_path)
+        self.hooks_dict['DistAlignHook'].p_model = checkpoint['p_model'].cuda(self.args.gpu)
+        self.hooks_dict['DistAlignHook'].p_target = checkpoint['p_target'].cuda(self.args.gpu)
+        self.hooks_dict['MaskingHook'].prob_max_mu_t = checkpoint['prob_max_mu_t'].cuda(self.args.gpu)
+        self.hooks_dict['MaskingHook'].prob_max_var_t = checkpoint['prob_max_var_t'].cuda(self.args.gpu)
+        self.print_fn("additional parameter loaded")
+        return checkpoint
+
+    @staticmethod
+    def get_argument():
+        return [
+            SSL_Argument('--hard_label', str2bool, True),
+            SSL_Argument('--T', float, 0.5),
+            SSL_Argument('--dist_align', str2bool, True),
+            SSL_Argument('--dist_uniform', str2bool, True),
+            SSL_Argument('--ema_p', float, 0.999),
+            SSL_Argument('--n_sigma', int, 2),
+            SSL_Argument('--per_class', str2bool, False),
+        ]
diff --git a/semilearn/algorithms/srsoftmatch/utils.py b/semilearn/algorithms/srsoftmatch/utils.py
new file mode 100644
index 0000000..0e32604
--- /dev/null
+++ b/semilearn/algorithms/srsoftmatch/utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+
+import torch
+
+from semilearn.algorithms.utils import concat_all_gather
+from semilearn.algorithms.hooks import MaskingHook
+
+
+class SoftMatchWeightingHook(MaskingHook):
+    """
+    SoftMatch learnable truncated Gaussian weighting
+    """
+    def __init__(self, num_classes, n_sigma=2, momentum=0.999, per_class=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.num_classes = num_classes
+        self.n_sigma = n_sigma
+        self.per_class = per_class
+        self.m = momentum
+
+        # initialize Gaussian mean and variance
+        if not self.per_class:
+            self.prob_max_mu_t = torch.tensor(1.0 / self.num_classes)
+            self.prob_max_var_t = torch.tensor(1.0)
+        else:
+            self.prob_max_mu_t = torch.ones((self.num_classes)) / self.num_classes
+            self.prob_max_var_t = torch.ones((self.num_classes))
+
+    @torch.no_grad()
+    def update(self, algorithm, probs_x_ulb):
+        if algorithm.distributed and algorithm.world_size > 1:
+            probs_x_ulb = concat_all_gather(probs_x_ulb)
+        max_probs, max_idx = probs_x_ulb.max(dim=-1)
+        if not self.per_class:
+            prob_max_mu_t = torch.mean(max_probs)  # torch.quantile(max_probs, 0.5)
+            prob_max_var_t = torch.var(max_probs, unbiased=True)
+            self.prob_max_mu_t = self.m * self.prob_max_mu_t + (1 - self.m) * prob_max_mu_t.item()
+            self.prob_max_var_t = self.m * self.prob_max_var_t + (1 - self.m) * prob_max_var_t.item()
+        else:
+            prob_max_mu_t = torch.zeros_like(self.prob_max_mu_t)
+            prob_max_var_t = torch.ones_like(self.prob_max_var_t)
+            for i in range(self.num_classes):
+                prob = max_probs[max_idx == i]
+                if len(prob) > 1:
+                    prob_max_mu_t[i] = torch.mean(prob)
+                    prob_max_var_t[i] = torch.var(prob, unbiased=True)
+            self.prob_max_mu_t = self.m * self.prob_max_mu_t + (1 - self.m) * prob_max_mu_t
+            self.prob_max_var_t = self.m * self.prob_max_var_t + (1 - self.m) * prob_max_var_t
+        return max_probs, max_idx
+
+    @torch.no_grad()
+    def masking(self, algorithm, logits_x_ulb, softmax_x_ulb=True, *args, **kwargs):
+        if not self.prob_max_mu_t.is_cuda:
+            self.prob_max_mu_t = self.prob_max_mu_t.to(logits_x_ulb.device)
+        if not self.prob_max_var_t.is_cuda:
+            self.prob_max_var_t = self.prob_max_var_t.to(logits_x_ulb.device)
+
+        if softmax_x_ulb:
+            probs_x_ulb = torch.softmax(logits_x_ulb.detach(), dim=-1)
+        else:
+            # logits is already probs
+            probs_x_ulb = logits_x_ulb.detach()
+
+        self.update(algorithm, probs_x_ulb)
+
+        max_probs, max_idx = probs_x_ulb.max(dim=-1)
+        # compute weight
+        if not self.per_class:
+            mu = self.prob_max_mu_t
+            var = self.prob_max_var_t
+        else:
+            mu = self.prob_max_mu_t[max_idx]
+            var = self.prob_max_var_t[max_idx]
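+        # truncated Gaussian weight: exp(-(min(p_max - mu, 0))^2 / (2 * var / n_sigma^2)),
+        # i.e. samples whose confidence exceeds the running mean mu receive full weight 1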
+        mask = torch.exp(-((torch.clamp(max_probs - mu, max=0.0) ** 2) / (2 * var / (self.n_sigma ** 2))))
+        return mask
\ No newline at end of file