diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 7185112..fee5db6 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -15,6 +15,7 @@ env: # Use docker.io for Docker Hub if empty GHCR_REGISTRY: ghcr.io ALIYUN_REGISTRY: registry.cn-hangzhou.aliyuncs.com + DOCKERHUB_REGISTRY: docker.io # github.repository as / IMAGE_NAME: ${{ github.repository }} @@ -99,4 +100,34 @@ jobs: tags: ${{ steps.meta-aliyun.outputs.tags }} labels: ${{ steps.meta-aliyun.outputs.labels }} cache-from: type=gha + cache-to: type=gha,mode=max + + # Login against a Docker registry except on PR + # https://github.com/docker/login-action + - name: Log into registry ${{ env.DOCKERHUB_REGISTRY }} + if: github.event_name != 'pull_request' + uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d # v3.0.0 + with: + registry: ${{ env.DOCKERHUB_REGISTRY }} + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + # Extract metadata (tags, labels) for Docker + # https://github.com/docker/metadata-action + - name: Extract Docker metadata + id: meta-dockerhub + uses: docker/metadata-action@96383f45573cb7f253c731d3b3ab81c87ef81934 # v5.0.0 + with: + images: ${{ env.DOCKERHUB_REGISTRY }}/${{ env.IMAGE_NAME }} + + # Build and push Docker image with Buildx (don't push on PR) + # https://github.com/docker/build-push-action + - name: Build and push Docker image + id: build-and-push-to-Dockerhub-container-registry + uses: docker/build-push-action@0565240e2d4ab88bba5387d719585280857ece09 # v5.0.0 + with: + context: . 
+ push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta-dockerhub.outputs.tags }} + labels: ${{ steps.meta-dockerhub.outputs.labels }} + cache-from: type=gha cache-to: type=gha,mode=max \ No newline at end of file diff --git a/CITATION.cff b/CITATION.cff new file mode 100755 index 0000000..d3202ec --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,25 @@ +# This CITATION.cff file was generated with cffinit. +# Visit https://bit.ly/cffinit to generate yours today! + +cff-version: 1.2.0 +title: 'FL-bench: A federated learning benchmark for solving image classification tasks' +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Jiahao + family-names: Tan + email: karhoutam@qq.com + affiliation: Shenzhen University + - given-names: Xinpeng + family-names: Wang + affiliation: 'The Chinese University of Hong Kong, Shenzhen' + email: 223015056@link.cuhk.edu.cn +repository-code: 'https://github.com/KarhouTam/FL-bench' +abstract: >- + Benchmark of federated learning that aims at solving image + classification tasks.
+keywords: + - federated learning +license: GPL-2.0 diff --git a/README.md b/README.md index 8da24ed..3fdc593 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,8 @@ - ***MetaFed*** -- [MetaFed: Federated Learning among Federations with Cyclic Knowledge Distillation for Personalized Healthcare](http://arxiv.org/abs/2206.08516) (IJCAI'22) +- ***FedRoD*** -- [On Bridging Generic and Personalized Federated Learning for Image Classification](https://arxiv.org/abs/2107.00778) (ICLR'22) + ### FL Domain Generalization Methods - ***FedSR*** -- [FedSR: A Simple and Effective Domain Generalization Method for Federated Learning](https://openreview.net/forum?id=mrt90D00aQX) (NIPS'22) @@ -251,3 +253,24 @@ Medical Image Datasets - [*COVID-19*](https://www.researchgate.net/publication/344295900_Curated_Dataset_for_COVID-19_Posterior-Anterior_Chest_Radiography_Images_X-Rays) (3 x 244 x 224, 4 classes) - [*Organ-S/A/CMNIST*](https://medmnist.com/) (1 x 28 x 28, 11 classes) + +## Citation 🧐 + +``` +@software{Tan_FL-bench, + author = {Tan, Jiahao and Wang, Xinpeng}, + license = {GPL-2.0}, + title = {{FL-bench: A federated learning benchmark for solving image classification tasks}}, + url = {https://github.com/KarhouTam/FL-bench} +} + +@misc{tan2023pfedsim, + title={pFedSim: Similarity-Aware Model Aggregation Towards Personalized Federated Learning}, + author={Jiahao Tan and Yipeng Zhou and Gang Liu and Jessie Hui Wang and Shui Yu}, + year={2023}, + eprint={2305.15706}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} + +``` diff --git a/src/client/apfl.py b/src/client/apfl.py index 5b15f00..fe40c0e 100644 --- a/src/client/apfl.py +++ b/src/client/apfl.py @@ -69,7 +69,6 @@ def fit(self): logit_g = self.model(x) logit_p = self.alpha * logit_l + (1 - self.alpha) * logit_g.detach() loss = self.criterion(logit_p, y) - self.optimizer.zero_grad() loss.backward() self.optimizer.step() @@ -93,9 +92,9 @@ def update_alpha(self): self.alpha.data -= self.args.local_lr * alpha_grad 
self.alpha.clip_(0, 1.0) - def evaluate(self, model=None, test_flag=False): + def evaluate(self): return super().evaluate( - MixedModel(self.local_model, self.model, alpha=self.alpha), test_flag + model=MixedModel(self.local_model, self.model, alpha=self.alpha) ) diff --git a/src/client/ditto.py b/src/client/ditto.py index fae98e8..596d349 100644 --- a/src/client/ditto.py +++ b/src/client/ditto.py @@ -60,5 +60,5 @@ def fit(self): ) self.optimizer.step() - def evaluate(self, model=None, test_flag=False) -> Dict[str, float]: - return super().evaluate(self.pers_model, test_flag) + def evaluate(self): + return super().evaluate(self.pers_model) diff --git a/src/client/fedavg.py b/src/client/fedavg.py index abb1170..ab6a581 100644 --- a/src/client/fedavg.py +++ b/src/client/fedavg.py @@ -12,6 +12,7 @@ PROJECT_DIR = Path(__file__).parent.parent.parent.absolute() from src.utils.tools import trainable_params, evalutate_model, Logger +from src.utils.metrics import Metrics from src.utils.models import DecoupledModel from src.utils.constants import DATA_MEAN, DATA_STD from data.utils.datasets import DATASETS @@ -69,6 +70,7 @@ def __init__( self.trainset: Subset = Subset(self.dataset, indices=[]) self.valset: Subset = Subset(self.dataset, indices=[]) self.testset: Subset = Subset(self.dataset, indices=[]) + self.test_flag = False self.model = model.to(self.device) self.local_epoch = self.args.local_epoch @@ -115,50 +117,35 @@ def train_and_log(self, verbose=False) -> Dict[str, Dict[str, float]]: Returns: Dict[str, Dict[str, float]]: The logging info, which contains metric stats. 
""" - eval_results = { - "before": { - "train": {"loss": 0, "correct": 0, "size": 0}, - "val": {"loss": 0, "correct": 0, "size": 0}, - "test": {"loss": 0, "correct": 0, "size": 0}, - }, - "after": { - "train": {"loss": 0, "correct": 0, "size": 0}, - "val": {"loss": 0, "correct": 0, "size": 0}, - "test": {"loss": 0, "correct": 0, "size": 0}, - }, + eval_metrics = { + "before": {"train": Metrics(), "val": Metrics(), "test": Metrics()}, + "after": {"train": Metrics(), "val": Metrics(), "test": Metrics()}, } - eval_results["before"] = self.evaluate() + eval_metrics["before"] = self.evaluate() if self.local_epoch > 0: self.fit() self.save_state() - eval_results["after"] = self.evaluate() + eval_metrics["after"] = self.evaluate() if verbose: - colors = {"train": "yellow", "val": "green", "test": "cyan"} - for split, flag, subset in [ - ["train", self.args.eval_train, self.trainset], - ["val", self.args.eval_val, self.valset], - ["test", self.args.eval_test, self.testset], + for split, color, flag, subset in [ + ["train", "yellow", self.args.eval_train, self.trainset], + ["val", "green", self.args.eval_val, self.valset], + ["test", "cyan", self.args.eval_test, self.testset], ]: if len(subset) > 0 and flag: self.logger.log( "client [{}] [{}]({}) loss: {:.4f} -> {:.4f} accuracy: {:.2f}% -> {:.2f}%".format( self.client_id, - colors[split], + color, split, - eval_results["before"][split]["loss"] - / eval_results["before"][split]["size"], - eval_results["after"][split]["loss"] - / eval_results["after"][split]["size"], - eval_results["before"][split]["correct"] - / eval_results["before"][split]["size"] - * 100.0, - eval_results["after"][split]["correct"] - / eval_results["after"][split]["size"] - * 100.0, + eval_metrics["before"][split].loss, + eval_metrics["after"][split].loss, + eval_metrics["before"][split].accuracy, + eval_metrics["after"][split].accuracy, ) ) - return eval_results + return eval_metrics def set_parameters(self, new_parameters: OrderedDict[str, 
torch.Tensor]): """Load model parameters received from the server. @@ -256,37 +243,35 @@ def fit(self): self.optimizer.step() @torch.no_grad() - def evaluate( - self, model: torch.nn.Module = None, force_eval=False - ) -> Dict[str, float]: + def evaluate(self, model: torch.nn.Module = None) -> Dict[str, Metrics]: """The evaluation function. Would be activated before and after local training if `eval_test = True` or `eval_train = True`. Args: model (torch.nn.Module, optional): The target model needed evaluation (set to `None` for using `self.model`). Defaults to None. force_eval (bool, optional): Set as `True` when the server asking client to evaluate model. Returns: - Dict[str, float]: The evaluation metric stats. + Dict[str, Metrics]: The evaluation metric stats. """ # disable train data transform while evaluating self.dataset.enable_train_transform = False target_model = self.model if model is None else model target_model.eval() - train_loss, val_loss, test_loss = 0, 0, 0 - train_correct, val_correct, test_correct = 0, 0, 0 - train_size, val_size, test_size = 0, 0, 0 + train_metrics = Metrics() + val_metrics = Metrics() + test_metrics = Metrics() criterion = torch.nn.CrossEntropyLoss(reduction="sum") if len(self.testset) > 0 and self.args.eval_test: - test_loss, test_correct, test_size = evalutate_model( + test_metrics = evalutate_model( model=target_model, dataloader=self.testloader, criterion=criterion, device=self.device, ) - if len(self.valset) > 0 and (force_eval or self.args.eval_val): - val_loss, val_correct, val_size = evalutate_model( + if len(self.valset) > 0 and self.args.eval_val: + val_metrics = evalutate_model( model=target_model, dataloader=self.valloader, criterion=criterion, @@ -294,7 +279,7 @@ def evaluate( ) if len(self.trainset) > 0 and self.args.eval_train: - train_loss, train_correct, train_size = evalutate_model( + train_metrics = evalutate_model( model=target_model, dataloader=self.trainloader, criterion=criterion, @@ -302,27 +287,11 @@ 
def evaluate( ) self.dataset.enable_train_transform = True - return { - "train": { - "loss": train_loss, - "correct": train_correct, - "size": float(max(1, train_size)), - }, - "val": { - "loss": val_loss, - "correct": val_correct, - "size": float(max(1, val_size)), - }, - "test": { - "loss": test_loss, - "correct": test_correct, - "size": float(max(1, test_size)), - }, - } + return {"train": train_metrics, "val": val_metrics, "test": test_metrics} def test( self, client_id: int, new_parameters: OrderedDict[str, torch.Tensor] - ) -> Dict[str, Dict[str, float]]: + ) -> Dict[str, Dict[str, Metrics]]: """Test function. Only be activated while in FL test round. Args: @@ -330,30 +299,25 @@ def test( new_parameters (OrderedDict[str, torch.Tensor]): The FL model parameters. Returns: - Dict[str, Dict[str, float]]: the evalutaion metrics stats. + Dict[str, Dict[str, Metrics]]: the evaluation metrics stats. """ + self.test_flag = True self.client_id = client_id self.load_dataset() self.set_parameters(new_parameters) - # set `size` as 1 for avoiding NaN.
results = { - "before": { - "train": {"loss": 0, "correct": 0, "size": 1}, - "val": {"loss": 0, "correct": 0, "size": 1}, - "test": {"loss": 0, "correct": 0, "size": 1}, - }, - "after": { - "train": {"loss": 0, "correct": 0, "size": 1}, - "val": {"loss": 0, "correct": 0, "size": 1}, - "test": {"loss": 0, "correct": 0, "size": 1}, - }, + "before": {"train": Metrics(), "val": Metrics(), "test": Metrics()}, + "after": {"train": Metrics(), "val": Metrics(), "test": Metrics()}, } - results["before"] = self.evaluate(force_eval=True) + results["before"] = self.evaluate() if self.args.finetune_epoch > 0: + frz_params_dict = deepcopy(self.model.state_dict()) self.finetune() - results["after"] = self.evaluate(force_eval=True) + results["after"] = self.evaluate() + self.model.load_state_dict(frz_params_dict) + self.test_flag = False return results def finetune(self): diff --git a/src/client/fedfomo.py b/src/client/fedfomo.py index 89d7997..a10aa15 100644 --- a/src/client/fedfomo.py +++ b/src/client/fedfomo.py @@ -60,8 +60,7 @@ def set_parameters(self, received_params: Dict[int, List[torch.Tensor]]): dataloader=self.valloader, criterion=self.criterion, device=self.device, - )[0] - LOSS /= len(self.valset) + ).loss W = torch.zeros(len(received_params), device=self.device) self.weight_vector.zero_() with torch.no_grad(): @@ -75,8 +74,7 @@ def set_parameters(self, received_params: Dict[int, List[torch.Tensor]]): dataloader=self.valloader, criterion=self.criterion, device=self.device, - )[0] - loss /= len(self.valset) + ).loss params_diff = vectorize(params_i) - vectorized_self_params w = (LOSS - loss) / (torch.norm(params_diff) + 1e-5) W[i] = w diff --git a/src/client/fedrod.py b/src/client/fedrod.py new file mode 100644 index 0000000..63d03db --- /dev/null +++ b/src/client/fedrod.py @@ -0,0 +1,186 @@ +from argparse import Namespace +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn.functional as F + +from fedavg import FedAvgClient 
+from src.utils.models import DecoupledModel +from src.utils.tools import Logger, count_labels, trainable_params + + +def balanced_softmax_loss( + logits: torch.Tensor, + targets: torch.Tensor, + gamma: float, + label_counts: torch.Tensor, +): + logits = logits + (label_counts**gamma).unsqueeze(0).expand(logits.shape).log() + loss = F.cross_entropy(logits, targets, reduction="mean") + return loss + + +class FedRoDClient(FedAvgClient): + def __init__( + self, + model: DecoupledModel, + hypernetwork: torch.nn.Module, + args: Namespace, + logger: Logger, + device: torch.device, + ): + super().__init__(FedRoDModel(model, args.eval_per), args, logger, device) + self.hypernetwork: torch.nn.Module = None + self.hyper_optimizer = None + if self.args.hyper: + self.hypernetwork = hypernetwork.to(self.device) + self.hyper_optimizer = torch.optim.SGD( + trainable_params(self.hypernetwork), lr=self.args.hyper_lr + ) + self.personal_params_name.extend( + [key for key, _ in self.model.named_parameters() if "personalized" in key] + ) + + def set_parameters(self, new_generic_parameters: OrderedDict[str, torch.Tensor]): + personal_parameters = self.personal_params_dict.get( + self.client_id, self.init_personal_params_dict + ) + self.optimizer.load_state_dict( + self.opt_state_dict.get(self.client_id, self.init_opt_state_dict) + ) + self.model.generic_model.load_state_dict(new_generic_parameters, strict=False) + self.model.load_state_dict(personal_parameters, strict=False) + + def train( + self, + client_id: int, + local_epoch: int, + new_parameters: OrderedDict[str, torch.Tensor], + hyper_parameters: OrderedDict[str, torch.Tensor], + return_diff=False, + verbose=False, + ): + self.client_id = client_id + if self.args.hyper: + self.hypernetwork.load_state_dict(hyper_parameters, strict=False) + self.local_epoch = local_epoch + self.load_dataset() + self.set_parameters(new_parameters) + eval_results = self.train_and_log(verbose=verbose) + + if return_diff: + delta = OrderedDict() + for 
(name, p0), p1 in zip( + new_parameters.items(), trainable_params(self.model.generic_model) + ): + delta[name] = p0 - p1 + + hyper_delta = None + if self.args.hyper: + hyper_delta = OrderedDict() + for (name, p0), p1 in zip( + hyper_parameters.items(), trainable_params(self.hypernetwork) + ): + hyper_delta[name] = p0 - p1 + + return delta, hyper_delta, len(self.trainset), eval_results + else: + return ( + trainable_params(self.model.generic_model, detach=True), + trainable_params(self.hypernetwork, detach=True), + len(self.trainset), + eval_results, + ) + + def fit(self): + label_counts = torch.tensor( + count_labels(self.dataset, self.trainset.indices), device=self.device + ) + # if using hypernetwork for generating personalized classifier parameters and client is first-time selected + if self.args.hyper and self.client_id not in self.personal_params_dict: + label_distrib = label_counts / label_counts.sum() + classifier_params = self.hypernetwork(label_distrib) + clf_weight_numel = self.model.generic_model.classifier.weight.numel() + self.model.personalized_classifier.weight.data = ( + classifier_params[:clf_weight_numel] + .reshape(self.model.personalized_classifier.weight.shape) + .detach() + .clone() + ) + self.model.personalized_classifier.bias.data = ( + classifier_params[clf_weight_numel:] + .reshape(self.model.personalized_classifier.bias.shape) + .detach() + .clone() + ) + + self.model.train() + for _ in range(self.local_epoch): + for x, y in self.trainloader: + if len(x) <= 1: + continue + + x, y = x.to(self.device), y.to(self.device) + logit_g, logit_p = self.model(x) + loss_g = balanced_softmax_loss( + logit_g, y, self.args.gamma, label_counts + ) + loss_p = self.criterion(logit_p, y) + loss = loss_g + loss_p + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + if self.args.hyper and self.client_id not in self.personal_params_dict: + # This part has no references on the FedRoD paper + trained_classifier_params = torch.cat( + [ + 
torch.flatten(self.model.personalized_classifier.weight.data), + torch.flatten(self.model.personalized_classifier.bias.data), + ] + ) + hyper_loss = F.mse_loss( + classifier_params, trained_classifier_params, reduction="sum" + ) + self.hyper_optimizer.zero_grad() + hyper_loss.backward() + self.hyper_optimizer.step() + + def finetune(self): + self.model.train() + for _ in range(self.args.finetune_epoch): + for x, y in self.trainloader: + if len(x) <= 1: + continue + + x, y = x.to(self.device), y.to(self.device) + if self.args.eval_per: + _, logit_p = self.model(x) + loss = self.criterion(logit_p, y) + else: + logit_g, _ = self.model(x) + loss = self.criterion(logit_g, y) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + +class FedRoDModel(DecoupledModel): + def __init__(self, generic_model: DecoupledModel, eval_per): + super().__init__() + self.generic_model = generic_model + self.personalized_classifier = deepcopy(generic_model.classifier) + self.eval_per = eval_per + + def forward(self, x): + z = torch.relu(self.generic_model.get_final_features(x, detach=False)) + logit_g = self.generic_model.classifier(z) + logit_p = self.personalized_classifier(z) + if self.training: + return logit_g, logit_p + else: + if self.eval_per: + return logit_p + else: + return logit_g diff --git a/src/client/knnper.py b/src/client/knnper.py index 15cdfe8..1fc9893 100644 --- a/src/client/knnper.py +++ b/src/client/knnper.py @@ -6,6 +6,7 @@ import faiss from fedavg import FedAvgClient +from src.utils.metrics import Metrics class kNNPerClient(FedAvgClient): @@ -14,16 +15,15 @@ def __init__(self, model, args, logger, device): self.datastore = DataStore(args, self.model.classifier.in_features) @torch.no_grad() - def evaluate(self, model=None, test_flag=False) -> Dict[str, float]: - if test_flag: + def evaluate(self, model=None): + if self.test_flag: self.dataset.enable_train_transform = False target_model = self.model if model is None else model target_model.eval() 
criterion = torch.nn.CrossEntropyLoss(reduction="sum") train_features, train_targets = [], [] - val_loss, test_loss = 0, 0 - val_correct, test_correct = 0, 0 - val_size, test_size = 0, 0 + val_metrics = Metrics() + test_metrics = Metrics() for x, y in self.trainloader: x, y = x.to(self.device), y.to(self.device) @@ -57,32 +57,18 @@ def _knnper_eval(dataloader): ) pred = torch.argmax(logits, dim=-1) loss = criterion(logits, targets).item() - correct = (pred == targets).sum().item() + return Metrics(loss, pred, targets) - return loss, correct, len(targets) - - if len(self.testset) > 0 and (test_flag and self.args.eval_test): - test_loss, test_correct, test_size = _knnper_eval(self.testloader) + if len(self.testset) > 0 and self.args.eval_test: + test_metrics = _knnper_eval(self.testloader) if len(self.valset) > 0 and self.args.eval_val: - val_loss, val_correct, val_size = _knnper_eval(self.valloader) + val_metrics = _knnper_eval(self.valloader) self.dataset.enable_train_transform = True # kNN-Per only do kNN trick in model test phase. So stats on training data are not offered. 
- return { - "train": {"loss": 0, "correct": 0, "size": 1.0}, - "val": { - "loss": val_loss, - "correct": val_correct, - "size": float(max(1, val_size)), - }, - "test": { - "loss": test_loss, - "correct": test_correct, - "size": float(max(1, test_size)), - }, - } + return {"train": Metrics(), "val": val_metrics, "test": test_metrics} else: - return super().evaluate(model, test_flag) + return super().evaluate(model) def get_knn_logits(self, features: torch.Tensor): distances, indices = self.datastore.index.search( diff --git a/src/client/metafed.py b/src/client/metafed.py index a5e187e..a49afac 100644 --- a/src/client/metafed.py +++ b/src/client/metafed.py @@ -35,11 +35,10 @@ def warmup(self, client_id, new_parameters): return trainable_params(self.model, detach=True) def update_flag(self): - _, val_correct, val_sample_num = evalutate_model( - self.model, self.valloader, device=self.device + metrics = evalutate_model(self.model, self.valloader, device=self.device) + self.client_flags[self.client_id] = ( + metrics.accuracy > self.args.threshold_1 ) - val_acc = val_correct / val_sample_num - self.client_flags[self.client_id] = val_acc > self.args.threshold_1 def train( self, @@ -84,14 +83,14 @@ def personalize( self.load_dataset() self.teacher.load_state_dict(teacher_parameters, strict=False) - _, student_correct, val_sample_num = evalutate_model( + student_metrics = evalutate_model( self.model, self.valloader, device=self.device ) - _, teacher_correct, _ = evalutate_model( + teacher_metrics = evalutate_model( self.teacher, self.valloader, device=self.device ) - teacher_acc = teacher_correct / val_sample_num - student_acc = student_correct / val_sample_num + teacher_acc = teacher_metrics.accuracy + student_acc = student_metrics.accuracy if teacher_acc <= student_acc and teacher_acc < self.args.threshold_2: self.lamda = 0 else: diff --git a/src/client/pfedme.py b/src/client/pfedme.py index 9df9680..1264edd 100644 --- a/src/client/pfedme.py +++ b/src/client/pfedme.py @@ 
-33,7 +33,6 @@ def train( self.load_dataset() self.set_parameters(new_parameters) self.local_parameters = trainable_params(new_parameters, detach=True) - # self.iter_trainloader = iter(self.trainloader) stats = self.train_and_log(verbose=verbose) return (deepcopy(self.local_parameters), len(self.trainset), stats) @@ -46,7 +45,6 @@ def save_state(self): def fit(self): self.model.train() for _ in range(self.args.local_epoch): - # x, y = self.get_data_batch() for x, y in self.trainloader: if len(x) <= 1: continue @@ -70,11 +68,11 @@ def fit(self): ) @torch.no_grad() - def evaluate(self, model=None, test_flag=False) -> Dict[str, Dict[str, float]]: + def evaluate(self) -> Dict[str, Dict[str, float]]: frz_model_params = deepcopy(self.model.state_dict()) if self.client_id in self.personalized_params_dict.keys(): self.model.load_state_dict(self.personalized_params_dict[self.client_id]) - res = super().evaluate(model, test_flag) + res = super().evaluate() self.model.load_state_dict(frz_model_params) return res diff --git a/src/server/adcol.py b/src/server/adcol.py index fd4b4d2..d8623eb 100644 --- a/src/server/adcol.py +++ b/src/server/adcol.py @@ -93,7 +93,7 @@ def train_one_round(self): ( delta, weight, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], self.features[client_id], ) = self.trainer.train( client_id=client_id, diff --git a/src/server/cfl.py b/src/server/cfl.py index e87abc2..0d218fd 100644 --- a/src/server/cfl.py +++ b/src/server/cfl.py @@ -45,7 +45,7 @@ def train_one_round(self): ( delta, _, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], @@ -116,8 +116,6 @@ def aggregate_clusterwise(self): param.data += diff self.delta_list = [None for _ in self.train_clients] - if self.current_epoch % 5 == 0: - print(self.client_clusters) @torch.no_grad() diff --git 
a/src/server/elastic.py b/src/server/elastic.py index 5f3e355..a62d09c 100644 --- a/src/server/elastic.py +++ b/src/server/elastic.py @@ -42,7 +42,7 @@ def train_one_round(self): ( delta, weight, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], sensitivity, ) = self.trainer.train( client_id=client_id, diff --git a/src/server/fedap.py b/src/server/fedap.py index 59b165e..5374799 100644 --- a/src/server/fedap.py +++ b/src/server/fedap.py @@ -76,7 +76,7 @@ def train(self): delta_cache = [] weight_cache = [] for client_id in self.selected_clients: - (delta, weight, self.client_stats[client_id][E]) = self.trainer.train( + (delta, weight, self.client_metrics[client_id][E]) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], new_parameters=self.global_params_dict, @@ -147,7 +147,7 @@ def train(self): delta_cache = [] for client_id in self.selected_clients: client_local_params = self.generate_client_params(client_id) - delta, _, self.client_stats[client_id][E] = self.trainer.train( + delta, _, self.client_metrics[client_id][E] = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], new_parameters=client_local_params, diff --git a/src/server/fedavg.py b/src/server/fedavg.py index c26fec7..e197f01 100644 --- a/src/server/fedavg.py +++ b/src/server/fedavg.py @@ -15,6 +15,8 @@ from rich.console import Console from rich.progress import track +from src.utils.metrics import Metrics + PROJECT_DIR = Path(__file__).parent.parent.parent.absolute() sys.path.append(PROJECT_DIR.as_posix()) @@ -27,7 +29,7 @@ trainable_params, get_optimal_cuda_device, ) -from src.utils.models import MODELS +from src.utils.models import MODELS, DecoupledModel from data.utils.datasets import DATASETS from src.client.fedavg import FedAvgClient @@ -113,7 +115,9 @@ def __init__( # get_model_arch() would return a class depends on model's name, # then init the model object by 
indicating the dataset and calling the class. # Finally transfer the model object to the target device. - self.model = MODELS[self.args.model](dataset=self.args.dataset).to(self.device) + self.model: DecoupledModel = MODELS[self.args.model]( + dataset=self.args.dataset + ).to(self.device) self.model.check_avaliability() # client_trainable_params is for pFL, which outputs exclusive model per client @@ -195,18 +199,10 @@ def __init__( + f"_{self.args.global_epoch}" + f"_{self.args.local_epoch}" ) - self.client_stats = {i: {} for i in self.train_clients} - self.metrics = { - "before": { - "train": {"accuracy": []}, - "val": {"accuracy": []}, - "test": {"accuracy": []}, - }, - "after": { - "train": {"accuracy": []}, - "val": {"accuracy": []}, - "test": {"accuracy": []}, - }, + self.client_metrics = {i: {} for i in self.train_clients} + self.global_metrics = { + "before": {"train": [], "val": [], "test": []}, + "after": {"train": [], "val": [], "test": []}, } stdout = Console(log_path=False, log_time=False) self.logger = Logger( @@ -217,7 +213,7 @@ def __init__( / self.output_dir / f"{self.args.dataset}_log.html", ) - self.eval_results: Dict[int, Dict[str, str]] = {} + self.test_results: Dict[int, Dict[str, Dict[str, Metrics]]] = {} self.train_progress_bar = track( range(self.args.global_epoch), "[bold green]Training...", console=stdout ) @@ -266,7 +262,7 @@ def train_one_round(self): ( delta, weight, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], @@ -283,95 +279,72 @@ def test(self): """The function for testing FL method's output (a single global model or personalized client models).""" self.test_flag = True client_ids = set(self.val_clients + self.test_clients) - split_sample_flag = False + all_same = False if client_ids: - if (set(self.train_clients) != set(self.val_clients)) or ( - set(self.train_clients) != 
set(self.test_clients) - ): + if self.val_clients == self.train_clients == self.test_clients: + all_same = True results = { - "val_clients": { + "all_clients": { "before": { - "train": {"loss": [], "correct": [], "size": []}, - "val": {"loss": [], "correct": [], "size": []}, - "test": {"loss": [], "correct": [], "size": []}, + "train": Metrics(), + "val": Metrics(), + "test": Metrics(), }, "after": { - "train": {"loss": [], "correct": [], "size": []}, - "val": {"loss": [], "correct": [], "size": []}, - "test": {"loss": [], "correct": [], "size": []}, + "train": Metrics(), + "val": Metrics(), + "test": Metrics(), }, - }, - "test_clients": { + } + } + else: + results = { + "val_clients": { "before": { - "train": {"loss": [], "correct": [], "size": []}, - "val": {"loss": [], "correct": [], "size": []}, - "test": {"loss": [], "correct": [], "size": []}, + "train": Metrics(), + "val": Metrics(), + "test": Metrics(), }, "after": { - "train": {"loss": [], "correct": [], "size": []}, - "val": {"loss": [], "correct": [], "size": []}, - "test": {"loss": [], "correct": [], "size": []}, + "train": Metrics(), + "val": Metrics(), + "test": Metrics(), }, }, - } - else: - split_sample_flag = True - results = { - "all_clients": { + "test_clients": { "before": { - "train": {"loss": [], "correct": [], "size": []}, - "val": {"loss": [], "correct": [], "size": []}, - "test": {"loss": [], "correct": [], "size": []}, + "train": Metrics(), + "val": Metrics(), + "test": Metrics(), }, "after": { - "train": {"loss": [], "correct": [], "size": []}, - "val": {"loss": [], "correct": [], "size": []}, - "test": {"loss": [], "correct": [], "size": []}, + "train": Metrics(), + "val": Metrics(), + "test": Metrics(), }, - } + }, } for cid in client_ids: client_local_params = self.generate_client_params(cid) - stats = self.trainer.test(cid, client_local_params) + client_metrics = self.trainer.test(cid, client_local_params) for stage in ["before", "after"]: for split in ["train", "val", "test"]: - for 
metric in ["loss", "correct", "size"]: - if split_sample_flag: - results["all_clients"][stage][split][metric].append( - stats[stage][split][metric] - ) - else: - if cid in self.val_clients: - results["val_clients"][stage][split][metric].append( - stats[stage][split][metric] - ) - if cid in self.test_clients: - results["test_clients"][stage][split][ - metric - ].append(stats[stage][split][metric]) - for group in results.keys(): - for stage in ["before", "after"]: - for split in ["train", "val", "test"]: - for metric in ["loss", "correct", "size"]: - results[group][stage][split][metric] = torch.tensor( - results[group][stage][split][metric] - ) - num_samples = results[group][stage][split]["size"].sum() - if num_samples > 0: - results[group][stage][split]["accuracy"] = ( - results[group][stage][split]["correct"].sum() - / num_samples - * 100 - ) - results[group][stage][split]["loss"] = ( - results[group][stage][split]["loss"].sum() / num_samples + if all_same: + results["all_clients"][stage][split].update( + client_metrics[stage][split] ) else: - results[group][stage][split]["accuracy"] = 0 - results[group][stage][split]["loss"] = 0 + if cid in self.val_clients: + results["val_clients"][stage][split].update( + client_metrics[stage][split] + ) + if cid in self.test_clients: + results["test_clients"][stage][split].update( + client_metrics[stage][split] + ) - self.eval_results[self.current_epoch + 1] = results + self.test_results[self.current_epoch + 1] = results self.test_flag = False @@ -461,70 +434,56 @@ def show_convergence(self): self.logger.log("=" * 10, self.algo, "Convergence on train clients", "=" * 10) for stage in ["before", "after"]: for split in ["train", "val", "test"]: - self.logger.log( - f"[{colors[split]}]{split}[/{colors[split]}] [{colors[stage]}]({stage} local training):" - ) - acc_range = [90.0, 80.0, 70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0] - min_acc_idx = 10 - max_acc = 0 - for E, acc in enumerate(self.metrics[stage][split]["accuracy"]): - for i, 
target in enumerate(acc_range): - if acc >= target and acc > max_acc: - self.logger.log(f"{target}%({acc:.2f}%) at epoch: {E}") - max_acc = acc - min_acc_idx = i - break - acc_range = acc_range[:min_acc_idx] + if len(self.global_metrics[stage][split]) > 0: + self.logger.log( + f"[{colors[split]}]{split}[/{colors[split]}] [{colors[stage]}]({stage} local training):" + ) + acc_range = [90.0, 80.0, 70.0, 60.0, 50.0, 40.0, 30.0, 20.0, 10.0] + min_acc_idx = 10 + max_acc = 0 + accuracies = [ + metrics.accuracy + for metrics in self.global_metrics[stage][split] + ] + for E, acc in enumerate(accuracies): + for i, target in enumerate(acc_range): + if acc >= target and acc > max_acc: + self.logger.log(f"{target}%({acc:.2f}%) at epoch: {E}") + max_acc = acc + min_acc_idx = i + break + acc_range = acc_range[:min_acc_idx] def log_info(self): """This function is for logging each selected client's training info.""" - for split in ["train", "val", "test"]: - correct_before = torch.tensor( - [ - self.client_stats[i][self.current_epoch]["before"][split]["correct"] - for i in self.selected_clients - ] - ) - correct_after = torch.tensor( - [ - self.client_stats[i][self.current_epoch]["after"][split]["correct"] - for i in self.selected_clients - ] - ) - num_samples = torch.tensor( - [ - self.client_stats[i][self.current_epoch]["before"][split]["size"] - for i in self.selected_clients - ] - ) - acc_before = ( - correct_before.sum(dim=-1, keepdim=True) / num_samples.sum() * 100.0 - ).item() - acc_after = ( - correct_after.sum(dim=-1, keepdim=True) / num_samples.sum() * 100.0 - ).item() - self.metrics["before"][split]["accuracy"].append(acc_before) - self.metrics["after"][split]["accuracy"].append(acc_after) - if self.args.visible: - self.viz.line( - [acc_before], - [self.current_epoch], - win=self.viz_win_name, - update="append", - name=f"{split}(before)", - opts=dict( - title=self.viz_win_name, - xlabel="Communication Rounds", - ylabel="Accuracy", - ), - ) - self.viz.line( - [acc_after], 
- [self.current_epoch], - win=self.viz_win_name, - update="append", - name=f"{split}(after)", - ) + for stage in ["before", "after"]: + for split, flag in [ + ("train", self.args.eval_train), + ("val", self.args.eval_val), + ("test", self.args.eval_test), + ]: + if flag: + global_metrics = Metrics() + for i in self.selected_clients: + global_metrics.update( + self.client_metrics[i][self.current_epoch][stage][split] + ) + + self.global_metrics[stage][split].append(global_metrics) + + if self.args.visible: + self.viz.line( + [global_metrics.accuracy], + [self.current_epoch], + win=self.viz_win_name, + update="append", + name=f"{split}({stage})", + opts=dict( + title=self.viz_win_name, + xlabel="Communication Rounds", + ylabel="Accuracy", + ), + ) def log_max_metrics(self): self.logger.log("=" * 20, self.algo, "Max Accuracy", "=" * 20) @@ -538,23 +497,35 @@ def log_max_metrics(self): } groups = ["val_clients", "test_clients"] - if set(self.train_clients) == set(self.val_clients) == set(self.test_clients): + if self.train_clients == self.val_clients == self.test_clients: groups = ["all_clients"] for group in groups: self.logger.log(f"{group}:") for stage in ["before", "after"]: - for split in ["train", "val", "test"]: - epoch, max_acc = max( - [ - (epoch, results[group][stage][split]["accuracy"]) - for epoch, results in self.eval_results.items() - ], - key=lambda tup: tup[1], - ) - self.logger.log( - f"[{colors[split]}]({split})[/{colors[split]}] [{colors[stage]}]{stage}[/{colors[stage]}] fine-tuning: {max_acc:.2f}% at epoch {epoch}" - ) + for split, flag in [ + ("train", self.args.eval_train), + ("val", self.args.eval_val), + ("test", self.args.eval_test), + ]: + if flag: + metrics_list = list( + map( + lambda tup: (tup[0], tup[1][group][stage][split]), + self.test_results.items(), + ) + ) + if len(metrics_list) > 0: + epoch, max_acc = max( + [ + (epoch, metrics.accuracy) + for epoch, metrics in metrics_list + ], + key=lambda tup: tup[1], + ) + self.logger.log( + 
f"[{colors[split]}]({split})[/{colors[split]}] [{colors[stage]}]{stage}[/{colors[stage]}] fine-tuning: {max_acc:.2f}% at epoch {epoch}" + ) def run(self): """The comprehensive FL process. @@ -586,14 +557,19 @@ def run(self): epoch: { group: { split: { - "loss": f"{metrics['before'][split]['loss']:.4f} -> {metrics['after'][split]['loss']:.4f}", - "accuracy": f"{metrics['before'][split]['accuracy']:.2f}% -> {metrics['after'][split]['accuracy']:.2f}%", + "loss": f"{metrics['before'][split].loss:.4f} -> {metrics['after'][split].loss:.4f}", + "accuracy": f"{metrics['before'][split].accuracy:.2f}% -> {metrics['after'][split].accuracy:.2f}%", } - for split in ["train", "val", "test"] + for split, flag in [ + ("train", self.args.eval_train), + ("val", self.args.eval_val), + ("test", self.args.eval_test), + ] + if flag } for group, metrics in results.items() } - for epoch, results in self.eval_results.items() + for epoch, results in self.test_results.items() } ) @@ -613,9 +589,12 @@ def run(self): } for stage in ["before", "after"]: for split in ["train", "val", "test"]: - if len(self.metrics[stage][split]["accuracy"]) > 0: + if len(self.global_metrics[stage][split]) > 0: plt.plot( - self.metrics[stage][split]["accuracy"], + [ + metrics.accuracy + for metrics in self.global_metrics[stage][split] + ], label=f"{split}_{stage}", ls=linestyle[stage][split], ) @@ -635,8 +614,18 @@ def run(self): df = pd.DataFrame() for stage in ["before", "after"]: for split in ["train", "val", "test"]: - for metric, stats in self.metrics[stage][split].items(): - if len(stats) > 0: + if len(self.global_metrics[stage][split]) > 0: + for metric in [ + "accuracy", + "micro_precision", + "macro_precision", + "micro_recall", + "macro_recall", + ]: + stats = [ + getattr(metrics, metric) + for metrics in self.global_metrics[stage][split] + ] df.insert( loc=df.shape[1], column=f"{metric}_{split}_{stage}", diff --git a/src/server/feddyn.py b/src/server/feddyn.py index 95c84b4..20846f0 100644 --- 
a/src/server/feddyn.py +++ b/src/server/feddyn.py @@ -49,7 +49,7 @@ def train_one_round(self): ( delta, _, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/fedfomo.py b/src/server/fedfomo.py index 3032403..31c1697 100644 --- a/src/server/fedfomo.py +++ b/src/server/fedfomo.py @@ -38,7 +38,7 @@ def train_one_round(self): ( client_params, weight_vector, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/fedgen.py b/src/server/fedgen.py index dd70224..f2bfb69 100644 --- a/src/server/fedgen.py +++ b/src/server/fedgen.py @@ -66,7 +66,7 @@ def train_one_round(self): delta, weight, label_counts, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/fediir.py b/src/server/fediir.py index 0fcaee0..c13ea06 100644 --- a/src/server/fediir.py +++ b/src/server/fediir.py @@ -45,7 +45,7 @@ def train_one_round(self): ( delta, _, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/fedmd.py b/src/server/fedmd.py index 8381aa7..d31f16d 100644 --- a/src/server/fedmd.py +++ b/src/server/fedmd.py @@ -74,7 +74,7 @@ def train_one_round(self): client_params = self.generate_client_params(client_id) ( client_params, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], 
diff --git a/src/server/fedopt.py b/src/server/fedopt.py index 0c72ea4..335f3f4 100644 --- a/src/server/fedopt.py +++ b/src/server/fedopt.py @@ -54,7 +54,7 @@ def train_one_round(self): ( delta, weight, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/fedrod.py b/src/server/fedrod.py new file mode 100644 index 0000000..a3efad8 --- /dev/null +++ b/src/server/fedrod.py @@ -0,0 +1,146 @@ +from argparse import Namespace +from collections import OrderedDict +from copy import deepcopy +from typing import List, OrderedDict + +import torch +import torch.nn as nn + +from fedavg import FedAvgServer, get_fedavg_argparser +from src.client.fedrod import FedRoDClient +from src.utils.tools import trainable_params +from src.utils.constants import NUM_CLASSES + + +def get_fedrod_argparser(): + parser = get_fedavg_argparser() + parser.add_argument("--gamma", type=float, default=1) + parser.add_argument("--hyper", type=int, default=0) + parser.add_argument("--hyper_lr", type=float, default=0.1) + parser.add_argument("--hyper_hidden_dim", type=int, default=32) + parser.add_argument("--eval_per", type=int, default=1) + return parser + + +class FedRoDServer(FedAvgServer): + def __init__( + self, + algo: str = "FedRoD", + args: Namespace = None, + unique_model=False, + default_trainer=False, + ): + super().__init__(algo, args, unique_model, default_trainer) + self.hyper_params_dict = None + self.hypernetwork: nn.Module = None + if self.args.hyper: + output_dim = ( + self.model.classifier.weight.numel() + + self.model.classifier.bias.numel() + ) + input_dim = NUM_CLASSES[self.args.dataset] + self.hypernetwork = HyperNetwork( + input_dim, self.args.hyper_hidden_dim, output_dim + ).to(self.device) + params, keys = trainable_params( + self.hypernetwork, detach=True, requires_name=True + ) + self.hyper_params_dict = 
OrderedDict(zip(keys, params)) + self.trainer = FedRoDClient( + model=deepcopy(self.model), + hypernetwork=deepcopy(self.hypernetwork), + args=self.args, + logger=self.logger, + device=self.device, + ) + + def train_one_round(self): + delta_cache = [] + weight_cache = [] + hyper_delta_cache = [] + for client_id in self.selected_clients: + client_local_params = self.generate_client_params(client_id) + ( + delta, + hyper_delta, + weight, + self.client_metrics[client_id][self.current_epoch], + ) = self.trainer.train( + client_id=client_id, + local_epoch=self.clients_local_epoch[client_id], + new_parameters=client_local_params, + hyper_parameters=self.hyper_params_dict, + verbose=((self.current_epoch + 1) % self.args.verbose_gap) == 0, + return_diff=False, + ) + + delta_cache.append(delta) + weight_cache.append(weight) + hyper_delta_cache.append(hyper_delta) + + self.aggregate(delta_cache, hyper_delta_cache, weight_cache) + + @torch.no_grad() + def aggregate( + self, + delta_cache: List[OrderedDict[str, torch.Tensor]], + hyper_delta_cache: List[OrderedDict[str, torch.Tensor]], + weight_cache: List[int], + return_diff=False, + ): + weights = torch.tensor(weight_cache, device=self.device) / sum(weight_cache) + if return_diff: + delta_list = [list(delta.values()) for delta in delta_cache] + aggregated_delta = [ + torch.sum(weights * torch.stack(diff, dim=-1), dim=-1) + for diff in zip(*delta_list) + ] + for param, diff in zip(self.global_params_dict.values(), aggregated_delta): + param.data -= diff + + if self.args.hyper: + hyper_delta_list = [list(delta.values()) for delta in hyper_delta_cache] + aggregated_hyper_delta = [ + torch.sum(weights * torch.stack(diff, dim=-1), dim=-1) + for diff in zip(*hyper_delta_list) + ] + for param, diff in zip( + self.hyper_params_dict.values(), aggregated_hyper_delta + ): + param.data -= diff + + else: + for old_param, zipped_new_param in zip( + self.global_params_dict.values(), zip(*delta_cache) + ): + old_param.data =
(torch.stack(zipped_new_param, dim=-1) * weights).sum( + dim=-1 + ) + self.model.load_state_dict(self.global_params_dict, strict=False) + + if self.args.hyper: + for old_param, zipped_new_param in zip( + self.hyper_params_dict.values(), zip(*hyper_delta_cache) + ): + old_param.data = ( + torch.stack(zipped_new_param, dim=-1) * weights + ).sum(dim=-1) + + if self.args.hyper: + self.hypernetwork.load_state_dict(self.hyper_params_dict) + self.model.load_state_dict(self.global_params_dict, strict=False) + + +class HyperNetwork(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None: + super().__init__() + self.model = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(True), + nn.Linear(hidden_dim, output_dim), + ) + + def forward(self, x): + return self.model(x) + + diff --git a/src/server/fedsr.py b/src/server/fedsr.py index 5b1b437..87431ea 100644 --- a/src/server/fedsr.py +++ b/src/server/fedsr.py @@ -38,7 +38,7 @@ def __init__(self, base_model: DecoupledModel, dataset) -> None: def featurize(self, x, num_samples=1, return_dist=False): # designed for FedSR - z_params = self.map_layer(self.base(x)) + z_params = F.relu(self.map_layer(F.relu(self.base(x)))) z_mu = z_params[:, : self.z_dim] z_sigma = F.softplus(z_params[:, self.z_dim :]) z_dist = distrib.Independent(distrib.normal.Normal(z_mu, z_sigma), 1) diff --git a/src/server/local.py b/src/server/local.py index 26155d7..2afa6c7 100644 --- a/src/server/local.py +++ b/src/server/local.py @@ -21,7 +21,7 @@ def train_one_round(self): ( client_params, _, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/metafed.py b/src/server/metafed.py index 77da5b7..36c944c 100644 --- a/src/server/metafed.py +++ b/src/server/metafed.py @@ -74,7 +74,7 @@ def train(self): teacher_params = self.generate_client_params( 
(client_id + self.client_num - 1) % self.client_num ) - student_params, self.client_stats[client_id][E] = self.trainer.train( + student_params, self.client_metrics[client_id][E] = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], student_parameters=student_params, @@ -107,7 +107,7 @@ def train(self): ( student_params, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.personalize( client_id=client_id, student_parameters=student_params, diff --git a/src/server/pfedhn.py b/src/server/pfedhn.py index 4c824af..0734c8d 100644 --- a/src/server/pfedhn.py +++ b/src/server/pfedhn.py @@ -79,7 +79,7 @@ def train_one_round(self): ( delta, weight, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/pfedla.py b/src/server/pfedla.py index e92e248..eace2e5 100644 --- a/src/server/pfedla.py +++ b/src/server/pfedla.py @@ -61,7 +61,7 @@ def train_one_round(self) -> None: ( delta, _, - self.client_stats[client_id][self.current_epoch], + self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/pfedsim.py b/src/server/pfedsim.py index 83db0a4..98c65b5 100644 --- a/src/server/pfedsim.py +++ b/src/server/pfedsim.py @@ -79,7 +79,7 @@ def train(self): ( client_params, _, - self.client_stats[client_id][E], + self.client_metrics[client_id][E], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/server/scaffold.py b/src/server/scaffold.py index 509233f..65cb08a 100644 --- a/src/server/scaffold.py +++ b/src/server/scaffold.py @@ -41,7 +41,7 @@ def train_one_round(self): ( y_delta, c_delta, - self.client_stats[client_id][self.current_epoch], + 
self.client_metrics[client_id][self.current_epoch], ) = self.trainer.train( client_id=client_id, local_epoch=self.clients_local_epoch[client_id], diff --git a/src/utils/metrics.py b/src/utils/metrics.py new file mode 100644 index 0000000..ab414e4 --- /dev/null +++ b/src/utils/metrics.py @@ -0,0 +1,77 @@ +import numpy as np +import torch +from sklearn import metrics + + +def to_numpy(x): + if isinstance(x, torch.Tensor): + return x.cpu().numpy() + elif isinstance(x, list): + return np.array(x) + else: + raise TypeError( + f"input data should be torch.Tensor or built-in list. Now {type(x)}" + ) + + +class Metrics: + def __init__(self, loss=None, predicts=None, targets=None): + self._loss = loss if loss is not None else 0.0 + self._targets = targets if targets is not None else [] + self._predicts = predicts if predicts is not None else [] + + def update(self, other): + if other is not None: + self._predicts.extend(to_numpy(other._predicts)) + self._targets.extend(to_numpy(other._targets)) + self._loss += other._loss + + def _calculate(self, metric, **kwargs): + return metric(self._targets, self._predicts, **kwargs) + + @property + def loss(self): + try: + loss = self._loss / len(self._targets) + except ZeroDivisionError: + return 0 + return loss + + @property + def macro_precision(self): + score = self._calculate( + metrics.precision_score, average="macro", zero_division=0 + ) + return score * 100 + + @property + def macro_recall(self): + score = self._calculate(metrics.recall_score, average="macro", zero_division=0) + return score * 100 + + @property + def micro_precision(self): + score = self._calculate( + metrics.precision_score, average="micro", zero_division=0 + ) + return score * 100 + + @property + def micro_recall(self): + score = self._calculate(metrics.recall_score, average="micro", zero_division=0) + return score * 100 + + @property + def accuracy(self): + if self.size == 0: + return 0 + score = self._calculate(metrics.accuracy_score) + return score * 100 + 
+ @property + def corrects(self): + return self._calculate(metrics.accuracy_score, normalize=False) + + @property + def size(self): + return len(self._targets) diff --git a/src/utils/tools.py b/src/utils/tools.py index d24d8ed..2cbd2c9 100644 --- a/src/utils/tools.py +++ b/src/utils/tools.py @@ -14,6 +14,7 @@ from rich.console import Console from data.utils.datasets import BaseDataset +from src.utils.metrics import Metrics PROJECT_DIR = Path(__file__).parent.parent.parent.absolute() OUT_DIR = PROJECT_DIR / "out" @@ -127,8 +128,8 @@ def evalutate_model( dataloader: DataLoader, criterion=torch.nn.CrossEntropyLoss(reduction="sum"), device=torch.device("cpu"), -) -> Tuple[float, float, int]: - """For evaluating the `model` over `dataloader` and return the result calculated by `criterion`. +) -> Metrics: + """For evaluating the `model` over `dataloader` and return metrics. Args: model (torch.nn.Module): Target model. @@ -137,20 +138,17 @@ def evalutate_model( device (torch.device, optional): The device that holds the computation. Defaults to torch.device("cpu"). Returns: - Tuple[float, float, int]: [loss, correct num, sample num] + Metrics: The metrics objective. """ model.eval() - correct = 0 - loss = 0 - sample_num = 0 + metrics = Metrics() for x, y in dataloader: x, y = x.to(device), y.to(device) logits = model(x) - loss += criterion(logits, y).item() + loss = criterion(logits, y).item() pred = torch.argmax(logits, -1) - correct += (pred == y).sum().item() - sample_num += len(y) - return loss, correct, sample_num + metrics.update(Metrics(loss, pred, y)) + return metrics def count_labels(