Skip to content

Commit

Permalink
Update srfreematch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WayneJin0918 authored Oct 5, 2023
1 parent ae95ab0 commit 160a5c8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions semilearn/algorithms/srfreematch/srfreematch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def __init__(self, args, net_builder, tb_log=None, logger=None):
self.init(T=args.T, hard_label=args.hard_label, ema_p=args.ema_p, use_quantile=args.use_quantile, clip_thresh=args.clip_thresh)
self.lambda_e = args.ent_loss_ratio
self.it=0
self.rewarder = Rewarder(128,384).cuda(self.gpu)
self.generator = Generator(384).cuda(self.gpu)
self.starttiming=20000
self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=0.0005)
self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0005)
self.rewarder = Rewarder(128,self.featinput).cuda(self.gpu)
self.generator = Generator(self.featinput).cuda(self.gpu)
self.starttiming=self.start
self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=self.srlr)
self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.srlr)
self.criterion = torch.nn.MSELoss()

self.semi_reward_infer = SemiReward_infer(self.rewarder, self.starttiming)
Expand Down Expand Up @@ -167,4 +167,4 @@ def get_argument():
SSL_Argument('--ent_loss_ratio', float, 0.01),
SSL_Argument('--use_quantile', str2bool, False),
SSL_Argument('--clip_thresh', str2bool, False),
]
]

0 comments on commit 160a5c8

Please sign in to comment.