-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
115 lines (104 loc) · 3.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import logging
import torch
from torch.autograd import Variable
import adversary as adv
from dataloader import DataDistributor
from options import ARGS
from utils import CUDA, calc_accuracy, writer
from worker import Role, Worker
# Derived configuration: count the honest (non-Byzantine) workers.
ARGS.num_v_workers_sim = ARGS.num_workers_sim - ARGS.num_b_workers_sim
ARGS.cuda = not ARGS.no_cuda and torch.cuda.is_available()
# Hyper-parameters used by the GAA training loop.
ARGS.alpha = 0.5
ARGS.truncated_bptt_step = 5

# Partition the dataset across the simulated workers; all workers share
# one held-out test loader.
data_distributor = DataDistributor(
    ARGS.dataset_path,
    ARGS.dataset,
    ARGS.batch_size,
    ARGS.num_workers_sim,
)
data_distributor.distribute()
train_loaders = data_distributor.train_loaders
test_loader = data_distributor.test_loader

# Resolve the configured attack name into the callable Byzantine
# workers will use.
attack_method = adv.attack_methods[ARGS.attack_method]().attack
def GAA():
    """Run the Gradient Aggregation Agent (GAA) training simulation.

    Builds one master (the aggregator) and ``ARGS.num_workers_sim``
    workers — the first ``ARGS.num_b_workers_sim`` of them Byzantine —
    then alternates between meta-model updates weighted by the learned
    alpha credibilities and truncated-BPTT updates of the policy model.
    Test accuracy and per-worker alpha weights are logged to the
    tensorboard ``writer`` and via ``logging``.
    """
    # Keyword arguments shared by the master and every worker; hoisted
    # so the two Worker(...) constructions cannot silently diverge.
    common_kwargs = dict(
        neighbors_n=ARGS.num_workers_sim,
        test_loader=test_loader,
        meta_lr=1e-2,
        policy_lr=1e-2,
        dataset=ARGS.dataset,
        missing_labels=None,
        period=2e8,
        alpha=0.5,
        extreme_mail=None,
        pretense=1e8,
    )
    # The master (wid=-1) only aggregates; it never attacks.
    master = Worker(
        -1,
        None,
        None,
        train_loader=train_loaders[0],
        role=Role.NORMAL,
        **common_kwargs,
    )
    workers = []
    for i in range(ARGS.num_workers_sim):
        # By convention the first num_b_workers_sim workers are Byzantine.
        byzantine = i < ARGS.num_b_workers_sim
        worker = Worker(
            wid=i,
            atk_fn=attack_method if byzantine else None,
            adv_loss=ARGS.adv_loss if byzantine else None,
            train_loader=train_loaders[i],
            role=Role.TRADITIONAL_ATTACK if byzantine else Role.NORMAL,
            **common_kwargs,
        )
        workers.append(worker)

    alpha_iter_no = 0  # global step counter across all BPTT iterations
    for i in range(ARGS.max_iter):
        test_accuracy = calc_accuracy(master.meta_model, test_loader)
        writer.add_scalar("data/test_accuracy", test_accuracy, alpha_iter_no)
        inner_step_count = ARGS.optimizer_steps // ARGS.truncated_bptt_step
        for k in range(inner_step_count):
            loss_sum = 0
            prev_loss = CUDA(torch.zeros(1))
            master.reset()
            for t in range(ARGS.truncated_bptt_step):
                alpha_iter_no += 1
                Q = []
                for worker in workers:
                    # Send current parameters \theta_t to each worker.
                    worker.copy_meta_params_from(master.meta_model)
                    # Receive submitted gradients Q_t (Byzantine workers
                    # submit attacked gradients via their atk_fn).
                    Q.append(worker.submit(alpha_iter_no))
                # Update the credibility weights alpha from Q.
                loss = master.alpha_update(Q)
                # Update the meta model \theta using the GAR.
                master.meta_update(Q)
                # Accumulate l_{GAA}: current loss minus the previous
                # loss, with the previous loss detached from the graph.
                if t > 0:
                    loss_sum += loss - Variable(prev_loss)
                prev_loss = loss.data
                writer.add_scalars(
                    "data/alpha",
                    {
                        "weight_{0}".format(w): master.alpha[w].data
                        for w in range(len(workers))
                    },
                    alpha_iter_no,
                )
            # Back-propagate l_{GAA} through the truncated window to
            # update the policy model.
            master.policy_update(loss_sum)
            # Periodic evaluation and logging.
            if alpha_iter_no % 100 == 0:
                test_acc = calc_accuracy(master.meta_model, master.test_loader)
                logging.info("Test Set Accuracy: {0}".format(test_acc))
                logging.info("Alpha: {0}".format(master.alpha.data))
                writer.add_scalar("data/test_accuracy", test_acc, alpha_iter_no)
if __name__ == "__main__":
    # Emit INFO-level progress logs from GAA().
    logging.basicConfig(level=logging.INFO)
    # Anomaly detection aids debugging of the truncated-BPTT autograd
    # graph, at a noticeable runtime cost.
    torch.autograd.set_detect_anomaly(True)
    GAA()