-
Notifications
You must be signed in to change notification settings - Fork 1
/
create_train.py
34 lines (29 loc) · 918 Bytes
/
create_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/env python
# -*-coding:utf-8 -*-
'''
@File : create_train.py
@Time : 2022/10/10 14:11:23
@Author : Bo
'''
import train_fed_avg as train_fed_avg
import train_scaffold as train_scaffold
import train_feddyn as train_feddyn
import train_fedprox as train_fedprox
import configs.conf as const
import torch
# Module-level target device, hard-coded to CUDA with no index, i.e. whatever
# CUDA device is current when a tensor is moved to it. Constructing
# torch.device("cuda") does not itself check for a GPU; the `.to(device)` call
# in the entry point below is what fails on a machine without CUDA.
# NOTE(review): there is no CPU fallback here. If CPU runs are ever needed,
# select via torch.cuda.is_available(), but confirm the train_* modules
# support it first.
device = torch.device("cuda")
def run(conf):
    """Dispatch training to the module that implements ``conf.aggregation``.

    Args:
        conf: Configuration object exposing an ``aggregation`` attribute (in
            this script it comes from ``configs.conf.give_fed_args()``). The
            whole object is forwarded unchanged to the selected module's
            ``train_with_conf``.

    Raises:
        ValueError: If ``conf.aggregation`` is not one of "fed_avg",
            "scaffold", "fed_pvr", "fed_dyn" or "fed_prox". Previously such a
            value was silently ignored and nothing was trained.
    """
    aggregation = conf.aggregation
    if aggregation == "fed_avg":
        trainer = train_fed_avg
    elif aggregation in ("scaffold", "fed_pvr"):
        # "fed_pvr" is routed to the SCAFFOLD trainer exactly as before;
        # presumably train_scaffold distinguishes the two via
        # conf.aggregation internally. TODO confirm.
        trainer = train_scaffold
    elif aggregation == "fed_dyn":
        trainer = train_feddyn
    elif aggregation == "fed_prox":
        trainer = train_fedprox
    else:
        # Fail loudly instead of exiting with no training done.
        raise ValueError(
            f"Unknown aggregation {aggregation!r}; expected one of "
            "'fed_avg', 'scaffold', 'fed_pvr', 'fed_dyn', 'fed_prox'."
        )
    trainer.train_with_conf(conf)
if __name__ == "__main__":
    # Put a throwaway one-element tensor on `device` before anything else.
    # Moving a tensor to a CUDA device initialises the CUDA context, so a
    # missing or unusable GPU surfaces here, up front. That is presumably the
    # intent, since the tensor is never read afterwards.
    _cuda_warmup = torch.zeros([1]).to(device)
    run(const.give_fed_args())