-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadversarial_optim.py
127 lines (104 loc) · 5.33 KB
/
adversarial_optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.optim as optim
import copy
import numpy as np
import math
from contextlib import contextmanager
######################################################
# ADVERSARIAL OPTIMIZER #
######################################################
class AdversarialWrapper(optim.Optimizer):
    """Optimizer that alternates between a task optimizer and an adversary optimizer.

    Wraps two existing optimizers: ``task_optim`` for the task parameters
    (e.g. featurizer and classifier) and ``adversary_optim`` for the adversary
    parameters (e.g. discriminator). In ``'train'`` mode every call to
    :meth:`step` performs either one task step or one adversary step, such
    that ``eta`` adversary steps are taken per task step.

    The wrapper also keeps a detached snapshot of every parameter together
    with the last change applied to it. :meth:`lookahead` uses these to
    temporarily move the parameters to an extrapolated ("predicted") point.

    Args:
        task_optim: optimizer over the task parameters.
        adversary_optim: optimizer over the adversary parameters. Its
            parameters must not overlap with those of ``task_optim``.
        eta: number of adversary steps per task step; must be >= 1.
        max_grad_norm: per-parameter gradient-norm threshold used by
            :meth:`clip_grads`. Defaults to 1.0.

    Raises:
        ValueError: if ``eta < 1``.
    """

    def __init__(self, task_optim, adversary_optim, eta=1, max_grad_norm=1.0):
        if eta < 1:
            raise ValueError("Invalid eta: {}".format(eta))
        # Register the parameter groups of both optimizers with the base class.
        # The base Optimizer constructor rejects a parameter that appears in
        # more than one group, so task and adversary parameters cannot overlap.
        # Afterwards every parameter is reachable via self.param_groups.
        params = task_optim.param_groups + adversary_optim.param_groups
        super().__init__(params, defaults=dict())
        # Snapshot of each parameter and the last change applied to it, stored
        # in _flat_params() order. Detached clones: they never need gradients.
        self._copy_params = [p.detach().clone() for p in self._flat_params()]
        self._last_diff = [torch.zeros_like(p) for p in self._copy_params]
        self.reset()  # Sync snapshots with the parameters and zero last_diff
        # Task and adversary optimizers
        self._task_optim = task_optim
        self._adv_optim = adversary_optim
        # Number of adversary steps per task step
        self._eta = eta
        # Adversary steps taken since the last task step. Starting at eta - 1
        # means the first step taken in 'train' mode is an adversary step.
        self._steps_since_task = self._eta - 1
        # Training mode: 'train' | 'task' | 'adversary'
        self._mode = 'train'
        # Per-parameter clipping threshold for gradient norms (see clip_grads)
        self.max_grad_norm = max_grad_norm

    def _flat_params(self):
        """Return every registered parameter as one flat list, in group order."""
        return [p for group in self.param_groups for p in group['params']]

    def update(self):
        """Refresh the stored snapshots after the parameters have changed.

        For each parameter, records ``param - snapshot`` as its last step and
        then overwrites the snapshot with the current value. Must be called
        after every parameter change if :meth:`lookahead` is used; can be
        ignored otherwise.
        """
        with torch.no_grad():
            for param, saved, diff in zip(self._flat_params(),
                                          self._copy_params, self._last_diff):
                diff.copy_(param - saved)  # must be computed before saved changes
                saved.copy_(param)

    def reset(self):
        """Re-sync the snapshots with the current parameters and zero the last steps.

        Useful when switching between pretraining and training, or when
        turning prediction on/off.
        """
        with torch.no_grad():
            for param, saved, diff in zip(self._flat_params(),
                                          self._copy_params, self._last_diff):
                saved.copy_(param)
                diff.zero_()

    def step_task(self, update_after=True, **kwargs):
        """Take one step with the task optimizer.

        Resets the count of adversary steps taken since the last task step.

        Args:
            update_after: if True, call :meth:`update` after stepping.
            **kwargs: forwarded to the task optimizer's ``step`` (e.g. ``closure``).

        Returns:
            Whatever the task optimizer's ``step`` returns (typically the
            closure's loss, or None when no closure is given).
        """
        loss = self._task_optim.step(**kwargs)
        self._steps_since_task = 0
        if update_after:
            self.update()
        return loss

    def step_adversary(self, update_after=True, **kwargs):
        """Take one step with the adversary optimizer.

        Increments the count of adversary steps taken since the last task step.

        Args:
            update_after: if True, call :meth:`update` after stepping.
            **kwargs: forwarded to the adversary optimizer's ``step``.

        Returns:
            Whatever the adversary optimizer's ``step`` returns.
        """
        loss = self._adv_optim.step(**kwargs)
        self._steps_since_task += 1
        if update_after:
            self.update()
        return loss

    def step(self, update_after=True, **kwargs):
        """Take the next step: one task step after every ``eta`` adversary steps.

        In 'train' mode gradients are first sanitized by :meth:`clip_grads`.
        In 'task' or 'adversary' mode only that kind of step is taken and no
        clipping is applied.

        Returns:
            The return value of the underlying optimizer's ``step``.
        """
        if self._mode == 'train':
            self.clip_grads()
        if self.step_type() == 'task':
            return self.step_task(update_after, **kwargs)
        return self.step_adversary(update_after, **kwargs)

    def step_type(self):
        """Return the type ('task' or 'adversary') of the next step that will be taken."""
        if self._mode != 'train':
            return self._mode
        if self._steps_since_task >= self._eta:
            return 'task'
        return 'adversary'

    def clip_grads(self):
        """Zero out NaN gradient entries, then clip each gradient's norm.

        Each parameter's gradient is clipped independently to
        ``self.max_grad_norm`` (not by the global norm across all
        parameters). Parameters without a gradient are skipped.
        """
        for param in self._flat_params():
            if param.grad is None:
                continue
            param.grad.masked_fill_(torch.isnan(param.grad), 0.0)
            torch.nn.utils.clip_grad_norm_(param, self.max_grad_norm)

    def mode(self, m):
        """Change the optimization mode.

        Args:
            m: 'train' alternates task and adversary steps; 'task' only takes
                task steps (for pretraining); 'adversary' only takes adversary
                steps (for pretraining).

        Raises:
            ValueError: if ``m`` is not one of the modes above.
        """
        if m not in ('train', 'task', 'adversary'):
            raise ValueError('Invalid AdversarialWrapper mode: {}'.format(m))
        self._mode = m

    @contextmanager
    def lookahead(self, step=1.0):
        """Temporarily move parameters to ``param + step * last_diff``.

        Inside the ``with`` block the parameters sit at the extrapolated
        point, so gradients can be computed there. On exit, including exit by
        exception, the parameters are restored to their stored snapshots.
        A ``step`` of 0.0 does nothing.

        Raises:
            RuntimeWarning: if any parameter differs from its snapshot, i.e.
                :meth:`update` was not called after the last modification. No
                parameter is moved in that case.
        """
        # If step is 0.0, do nothing
        if step == 0.0:
            yield
            return
        params = self._flat_params()
        # Integrity check on ALL parameters before moving any of them, so a
        # failure never leaves the network half-extrapolated.
        for p, saved in zip(params, self._copy_params):
            if not torch.equal(p, saved):
                raise RuntimeWarning("Stored parameters differ from current ones. Use step(update_after=True) when taking an optimization step, or manually call update() after each modification to the network parameters.")
        with torch.no_grad():
            for p, diff in zip(params, self._last_diff):
                p.add_(step * diff)
        try:
            yield
        finally:
            # Roll back to the stored values (the snapshots equal the
            # pre-lookahead parameters, as checked above).
            with torch.no_grad():
                for p, saved in zip(params, self._copy_params):
                    p.copy_(saved)