✨ feat(Big): Multi updates
Refactor how experiment arguments are set and how FL method classes consume them.
All arguments are now set mostly via the config file; setting common experiment arguments via the CLI is no longer supported. Only FL-method-specific arguments can still be set via the CLI, and those take priority over the config file and the defaults.
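
For illustration, a run under the new scheme might look like the line below (the `--mu` flag is an assumption, mirroring the fedprox example in config/template.yml; only method-specific arguments may be passed this way):

    python main.py fedprox config/template.yml --mu 0.01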

Beautify the terminal printout of the argument dict.

Rename some variables.

Add more type hints.

Make README.md more readable.

Try to make the code more elegant.
KarhouTam committed Apr 4, 2024
1 parent e28f200 commit 89dc572
Showing 71 changed files with 1,283 additions and 1,066 deletions.
167 changes: 115 additions & 52 deletions README.md

Large diffs are not rendered by default.

32 changes: 0 additions & 32 deletions config/template.yaml

This file was deleted.

51 changes: 51 additions & 0 deletions config/template.yml
@@ -0,0 +1,51 @@
# Full explanations are listed in README.md
common:
dataset: mnist
seed: 42
model: lenet5
join_ratio: 0.1
global_epoch: 100
local_epoch: 5
finetune_epoch: 0
batch_size: 32
test_interval: 100
straggler_ratio: 0
straggler_min_local_epoch: 0
external_model_params_file: null
optimizer:
name: sgd # [sgd, adam, adamw, rmsprop, adagrad]
lr: 0.01
dampening: 0 # [sgd]
weight_decay: 0
momentum: 0 # [sgd, rmsprop]
alpha: 0.99 # [rmsprop]
nesterov: false # [sgd]
betas: [0.9, 0.999] # [adam, adamw]

eval_test: true
eval_val: false
eval_train: false

verbose_gap: 10
visible: false
use_cuda: true
save_log: true
save_model: false
save_fig: true
save_metrics: true
viz_win_name: null
check_convergence: true


# You can also set arguments specific to each FL method
# FL-bench accesses FL method arguments via args.<method>.<arg>
# e.g.,
# fedprox:
# mu: 0.01
#
# pfedsim:
# warmup_round: 0.7
# ...


# NOTE: For arguments not mentioned here, the default values are set in `get_<method>_args()` in `src/server/<method>.py`
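
For reference, a minimal sketch of the `get_<method>_args()` convention the NOTE refers to, with hypothetical names and defaults (the real ones live in `src/server/<method>.py`):

    # Hypothetical sketch -- not the actual FL-bench implementation.
    from argparse import ArgumentParser, Namespace
    from typing import List, Optional

    def get_fedprox_args(args_list: Optional[List[str]] = None) -> Namespace:
        parser = ArgumentParser()
        # Default used when neither the config file nor the CLI sets it.
        parser.add_argument("--mu", type=float, default=0.01)
        return parser.parse_args(args_list)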
1 change: 0 additions & 1 deletion data/utils/process.py
@@ -18,7 +18,6 @@
def prune_args(args: Namespace) -> Dict:
args_dict = {}
# general settings
args_dict["dataset"] = args.dataset
args_dict["client_num"] = args.client_num
args_dict["test_ratio"] = args.test_ratio
args_dict["val_ratio"] = args.val_ratio
2 changes: 1 addition & 1 deletion generate_data.py
@@ -40,7 +40,7 @@ def main(args):
partition = {"separation": None, "data_indices": [[] for _ in range(client_num)]}
stats = {}
dataset: BaseDataset = None

if args.dataset == "femnist":
dataset = process_femnist(args, partition, stats)
elif args.dataset == "celeba":
42 changes: 29 additions & 13 deletions main.py
@@ -3,31 +3,47 @@
import inspect
from pathlib import Path

sys.path.append(Path(__file__).parent.joinpath("src/server").absolute().as_posix())
FLBENCH_ROOT = Path(__file__).parent.absolute()
if FLBENCH_ROOT not in sys.path:
sys.path.append(FLBENCH_ROOT.as_posix())

SERVER_DIR = Path(__file__).parent.joinpath("src/server").absolute()
if SERVER_DIR not in sys.path:
sys.path.append(SERVER_DIR.as_posix())

from src.utils.tools import parse_args


if __name__ == "__main__":
if len(sys.argv) < 2:
if len(sys.argv) < 3:
raise ValueError(
"Need to assign a method. Run like `python main.py <method> [args ...]`, e.g., python main.py fedavg -d cifar10 -m lenet5`"
"No <method> or <config_file>. Run like `python main.py <method> <config_file_relative_path> [cli_method_args ...]`,\n e.g., python main.py fedavg config/template.yml`"
)

method = sys.argv[1]
args_list = sys.argv[2:]
method_name = sys.argv[1]
config_file_path = sys.argv[2]
cli_method_args = sys.argv[3:]
try:
method_module = importlib.import_module(method_name)
except:
raise FileNotFoundError(f"unrecognized method: {method_name}.")

module = importlib.import_module(method)
try:
get_argparser = getattr(module, f"get_{method}_argparser")
get_method_args_func = getattr(method_module, f"get_{method_name}_args")
except:
fedavg_module = importlib.import_module("fedavg")
get_argparser = getattr(fedavg_module, "get_fedavg_argparser")
parser = get_argparser()
module_attributes = inspect.getmembers(module, inspect.isclass)
get_method_args_func = None

module_attributes = inspect.getmembers(method_module, inspect.isclass)
server_class = [
attribute
for attribute in module_attributes
if attribute[0].lower() == method + "server"
if attribute[0].lower() == method_name + "server"
][0][1]

server = server_class(args=parser.parse_args(args_list))
server = server_class(
args=parse_args(
config_file_path, method_name, get_method_args_func, cli_method_args
)
)

server.run()
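
Judging from the call above and the commit message, `parse_args` merges the method defaults, the config file, and the CLI arguments, with later sources winning. A hedged sketch of that precedence (illustrative names only; the actual logic lives in `src/utils/tools.py`):

    # Hypothetical sketch of the merge order, not the real parse_args.
    def merge_method_args(defaults: dict, config: dict, cli: dict) -> dict:
        merged = dict(defaults)   # lowest priority: get_<method>_args() defaults
        merged.update(config)     # overridden by values from the config file
        merged.update(cli)        # CLI method arguments take top priority
        return merged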
15 changes: 12 additions & 3 deletions src/client/adcol.py
@@ -6,11 +6,20 @@
import torch.nn as nn

from fedavg import FedAvgClient
from src.utils.tools import trainable_params
from src.utils.tools import Logger, NestedNamespace, trainable_params
from src.utils.models import DecoupledModel


class ADCOLClient(FedAvgClient):
def __init__(self, model, discriminator, args, logger, device, client_num):
def __init__(
self,
model: DecoupledModel,
discriminator: torch.nn.Module,
args: NestedNamespace,
logger: Logger,
device: torch.device,
client_num: int,
):
super(ADCOLClient, self).__init__(model, args, logger, device)
self.discriminator = discriminator
self.discriminator.to(self.device)
@@ -44,7 +53,7 @@ def fit(self):
target_index_softmax = F.softmax(target_index, dim=-1)
kl_loss_func = nn.KLDivLoss(reduction="batchmean").to(self.device)
kl_loss = kl_loss_func(client_index_softmax, target_index_softmax)
mu = self.args.mu
mu = self.args.adcol.mu

loss = cross_entropy + mu * kl_loss
self.optimizer.zero_grad()
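
Reading off the hunk above, the client objective is

    \mathcal{L} = \mathcal{L}_{\mathrm{CE}} + \mu \cdot \mathcal{L}_{\mathrm{KL}}

where \mathcal{L}_{\mathrm{KL}} is the batch-mean KL divergence computed from the discriminator softmax outputs, and \mu is now read from args.adcol.mu rather than the flat args.mu.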
26 changes: 19 additions & 7 deletions src/client/apfl.py
@@ -4,17 +4,26 @@
import torch

from fedavg import FedAvgClient
from src.utils.tools import trainable_params
from src.utils.tools import Logger, trainable_params, NestedNamespace
from src.utils.models import DecoupledModel


class APFLClient(FedAvgClient):
def __init__(self, model, args, logger, device, client_num):
def __init__(
self,
model: DecoupledModel,
args: NestedNamespace,
logger: Logger,
device: torch.device,
client_num: int,
):
super().__init__(model, args, logger, device)

self.alpha_list = [
torch.tensor(self.args.alpha, device=self.device) for _ in range(client_num)
torch.tensor(self.args.apfl.alpha, device=self.device)
for _ in range(client_num)
]
self.alpha = torch.tensor(self.args.alpha, device=self.device)
self.alpha = torch.tensor(self.args.apfl.alpha, device=self.device)

self.local_model = deepcopy(self.model)

@@ -36,7 +45,10 @@ def re_init(src):
}

self.optimizer.add_param_group(
{"params": trainable_params(self.local_model), "lr": self.args.local_lr}
{
"params": trainable_params(self.local_model),
"lr": self.args.common.optimizer.lr,
}
)
self.init_opt_state_dict = deepcopy(self.optimizer.state_dict())

@@ -73,7 +85,7 @@ def fit(self):
loss.backward()
self.optimizer.step()

if self.args.adaptive_alpha and i == 0:
if self.args.apfl.adaptive_alpha and i == 0:
self.update_alpha()

# refers to https://github.com/MLOPTPSU/FedTorch/blob/b58da7408d783fd426872b63fbe0c0352c7fa8e4/fedtorch/comms/utils/flow_utils.py#L240
@@ -90,7 +102,7 @@ def update_alpha(self):
alpha_grad += diff @ grad

alpha_grad += 0.02 * self.alpha
self.alpha.data -= self.args.local_lr * alpha_grad
self.alpha.data -= self.args.common.optimizer.lr * alpha_grad
self.alpha.clip_(0, 1.0)

def evaluate(self):
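
The visible lines of update_alpha implement a projected gradient step on the mixing weight:

    \alpha \leftarrow \operatorname{clip}\!\Big(\alpha - \eta \big(\textstyle\sum_{\ell} \langle v_\ell - w_\ell,\, g_\ell \rangle + 0.02\,\alpha \big),\; 0,\; 1\Big)

where \eta is now args.common.optimizer.lr (previously args.local_lr), and the v_\ell - w_\ell and g_\ell terms come from the `diff @ grad` accumulation in the partially elided loop above.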
23 changes: 17 additions & 6 deletions src/client/ditto.py
@@ -1,21 +1,32 @@
from copy import deepcopy
from typing import Dict, OrderedDict
from typing import OrderedDict

import torch

from fedavg import FedAvgClient
from src.utils.tools import trainable_params
from src.utils.tools import Logger, NestedNamespace, trainable_params
from src.utils.models import DecoupledModel


class DittoClient(FedAvgClient):
def __init__(self, model, args, logger, device, client_num):
def __init__(
self,
model: DecoupledModel,
args: NestedNamespace,
logger: Logger,
device: torch.device,
client_num: int,
):
super().__init__(model, args, logger, device)
self.pers_model = deepcopy(model)
self.pers_model_params_dict = {
cid: deepcopy(self.pers_model.state_dict()) for cid in range(client_num)
}
self.optimizer.add_param_group(
{"params": trainable_params(self.pers_model), "lr": self.args.local_lr}
{
"params": trainable_params(self.pers_model),
"lr": self.args.common.optimizer.lr,
}
)
self.init_opt_state_dict = deepcopy(self.optimizer.state_dict())

@@ -45,7 +56,7 @@ def fit(self):
loss.backward()
self.optimizer.step()

for _ in range(self.args.pers_epoch):
for _ in range(self.args.ditto.pers_epoch):
for x, y in self.trainloader:
x, y = x.to(self.device), y.to(self.device)
logit = self.pers_model(x)
@@ -56,7 +67,7 @@ def fit(self):
trainable_params(self.pers_model),
trainable_params(self.global_params),
):
pers_param.grad.data += self.args.lamda * (
pers_param.grad.data += self.args.ditto.lamda * (
pers_param.data - global_param.data
)
self.optimizer.step()
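
The added gradient term matches Ditto's proximal objective for the personalized model v against the global model w:

    \min_{v}\; f_k(v) + \frac{\lambda}{2}\,\lVert v - w \rVert^2

whose gradient contributes exactly the `self.args.ditto.lamda * (pers_param.data - global_param.data)` correction above, with \lambda now nested under args.ditto.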
5 changes: 3 additions & 2 deletions src/client/elastic.py
@@ -8,6 +8,7 @@
from fedavg import FedAvgClient
from src.utils.models import DecoupledModel
from src.utils.tools import Logger, trainable_params
from src.utils.models import DecoupledModel


class ElasticClient(FedAvgClient):
@@ -52,8 +53,8 @@ def train(
]
for i in range(len(grads_norm)):
self.sensitivity[self.client_id][i] = (
self.args.mu * self.sensitivity[self.client_id][i]
+ (1 - self.args.mu) * grads_norm[i].abs()
self.args.elastic.mu * self.sensitivity[self.client_id][i]
+ (1 - self.args.elastic.mu) * grads_norm[i].abs()
)

eval_results = self.train_and_log(verbose=verbose)
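
The hunk keeps the same exponential moving average of per-layer sensitivities, only with \mu moved to args.elastic.mu:

    s_i \leftarrow \mu\, s_i + (1 - \mu)\,\lvert g_i \rvert

where g_i is layer i's gradient norm (grads_norm[i] above).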
17 changes: 13 additions & 4 deletions src/client/fedap.py
@@ -3,18 +3,27 @@
import torch

from fedbn import FedBNClient
from src.utils.models import DecoupledModel
from src.utils.tools import Logger, NestedNamespace


class FedAPClient(FedBNClient):
def __init__(self, model, args, logger, device):
def __init__(
self,
model: DecoupledModel,
args: NestedNamespace,
logger: Logger,
device: torch.device,
):
super(FedAPClient, self).__init__(model, args, logger, device)

self.model.need_all_features()
self.pretrain = False

def load_dataset(self):
super().load_dataset()
num_pretrain_samples = int(self.args.pretrain_ratio * len(self.trainset))
if self.args.version != "f":
num_pretrain_samples = int(self.args.fedap.pretrain_ratio * len(self.trainset))
if self.args.fedap.version != "f":
if self.pretrain:
self.trainset.indices = self.trainset.indices[:num_pretrain_samples]
else:
@@ -40,7 +49,7 @@ def get_all_features(

self.save_state()

if self.args.version == "d":
if self.args.fedap.version == "d":
for i, features in enumerate(features_list):
for j in range(len(features)):
if len(features[j].shape) == 4 and len(features[j + 1].shape) < 4:
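
For versions other than "f", load_dataset now splits off the warm-up data as

    n_{\text{pretrain}} = \big\lfloor \texttt{args.fedap.pretrain\_ratio} \cdot |\text{trainset}| \big\rfloor

using the first n_pretrain indices during the pretrain phase and, presumably, the remaining indices afterwards (the else branch is truncated here).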