diff --git a/.environment/Dockerfile b/.environment/Dockerfile
index 395b229..ef0f894 100644
--- a/.environment/Dockerfile
+++ b/.environment/Dockerfile
@@ -13,5 +13,5 @@ RUN pip install --upgrade pip && \
     pip install -r /tmp/requirements.txt
 
-# port for visdom
-EXPOSE 8097 8265
+# ports: visdom (8097), ray dashboard (8265), tensorboard (6006)
+EXPOSE 8097 8265 6006
diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml
index 4792a80..89445ce 100644
--- a/.github/workflows/docker-publish.yml
+++ b/.github/workflows/docker-publish.yml
@@ -1,72 +1,126 @@
-name: Docker-Publish
+name: Docker Image Publish
 
 on:
   push:
     paths:
-      - '.environment/**'
-      - '.github/workflows/docker-publish.yml'
-
+      - ".environment/**"
+      - ".github/workflows/docker-publish.yml"
 
 env:
+  IMAGE_LOWERCASE_NAME: fl-bench
+  IMAGE_LOWERCASE_OWNER: karhoutam
   GITHUB_REGISTRY: ghcr.io
   ALIYUN_REGISTRY: registry.cn-hangzhou.aliyuncs.com
   DOCKERHUB_REGISTRY: docker.io
   IMAGE_TAG: master
 
 jobs:
-  build:
+  build-image:
+    name: Build Docker Image
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      # https://github.com/docker/build-push-action
+      - name: Build Docker image
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: .environment/Dockerfile
+          push: false
+          tags: |
+            ${{ env.GITHUB_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
+            ${{ env.ALIYUN_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
+            ${{ env.DOCKERHUB_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
+          cache-to: type=gha,mode=max
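+          # NOTE: push is disabled in this job on purpose -- it only builds the
+          # image and saves its layers to the GitHub Actions cache (cache-to),
+          # which the per-registry push jobs below restore via cache-from.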
+
+  push-ghcr:
+    name: Push to ghcr
+    needs: build-image
     runs-on: ubuntu-latest
-    permissions:
-      contents: read
-      packages: write
-      id-token: write
 
     steps:
-      - name: Set Lowercase Variables
-        run: |
-          echo "IMAGE_LOWERCASE_OWNER=$(echo ${{ github.actor }} | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV
-          echo "IMAGE_LOWERCASE_NAME=$(echo ${{ github.event.repository.name }} | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV
       - name: Checkout repository
         uses: actions/checkout@v4
 
-      # https://github.com/docker/setup-buildx-action
+
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
 
-      # https://github.com/docker/login-action
+
       - name: Log into ghcr.io
-        if: github.event_name != 'pull_request'
         uses: docker/login-action@v3
         with:
          registry: ${{ env.GITHUB_REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
 
-      - name: Log into Docker Hub
-        if: github.event_name != 'pull_request'
+
+      - name: Build and push Docker image
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: .environment/Dockerfile
+          push: true
+          tags: |
+            ${{ env.GITHUB_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
+          cache-from: type=gha
+
+  push-dockerhub:
+    name: Push to dockerhub
+    needs: build-image
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Log into Docker Hub
         uses: docker/login-action@v3
         with:
          registry: ${{ env.DOCKERHUB_REGISTRY }}
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
-      - name: Log into Aliyun
-        if: github.event_name != 'pull_request'
+
+      - name: Build and push Docker image
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          file: .environment/Dockerfile
+          push: true
+          tags: |
+            ${{ env.DOCKERHUB_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
+          cache-from: type=gha
+
+  push-aliyun:
+    name: Push to aliyun
+    needs: build-image
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Log into Aliyun
         uses: docker/login-action@v3
         with:
          registry: ${{ env.ALIYUN_REGISTRY }}
          username: ${{ secrets.ALIYUN_USERNAME }}
          password: ${{ secrets.ALIYUN_TOKEN }}
-      # https://github.com/docker/build-push-action
+
       - name: Build and push Docker image
         uses: docker/build-push-action@v5
         with:
          context: .
          file: .environment/Dockerfile
-          push: ${{ github.event_name != 'pull_request' }}
-          build-args: |
-            REPO_PATH=.
-            REPO_NAME=${{ github.event.repository.name }}
+          push: true
          tags: |
-            ${{ env.GITHUB_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
            ${{ env.ALIYUN_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
-            ${{ env.DOCKERHUB_REGISTRY }}/${{ env.IMAGE_LOWERCASE_OWNER }}/${{ env.IMAGE_LOWERCASE_NAME }}:${{ env.IMAGE_TAG }}
          cache-from: type=gha
-          cache-to: type=gha,mode=max
\ No newline at end of file
diff --git a/README.md b/README.md
index 4c1010e..67095cd 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,8 @@ Having Fun with Federated Learning.
 
 - ***FedOpt*** -- [Adaptive Federated Optimization](https://arxiv.org/abs/2003.00295) (ICLR'21)
 
 - ***Elastic Aggregation*** -- [Elastic Aggregation for Federated Optimization](https://openaccess.thecvf.com/content/CVPR2023/html/Chen_Elastic_Aggregation_for_Federated_Optimization_CVPR_2023_paper.html) (CVPR'23)
+
+- ***FedFed*** -- [FedFed: Feature Distillation against Data Heterogeneity in Federated Learning](http://arxiv.org/abs/2310.05077) (NeurIPS'23)
@@ -239,22 +241,22 @@ parallel:
 ```
 
 ### Manually Create `Ray` Cluster (Optional)
 
 A `Ray` cluster would be created implicitly everytime you run experiment in parallel mode.
-Or you can create it manually to avoid creating and destroying cluster every time you run experiment.
+Alternatively, you can create one manually with the command shown below, which avoids creating and destroying a cluster on every run.
+```shell
+ray start --head [OPTIONS]
+```
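+For example, `ray start --head --num-cpus=16 --num-gpus=2` caps the cluster at 16 CPUs and 2 GPUs; both are standard `ray start` flags, so adjust them to your machine.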
+👀 **NOTE:** You need to keep `num_cpus: null` and `num_gpus: null` in your config file for connecting to an existing `Ray` cluster.
 ```yaml
 # your_config_file.yml
 # Connect to an existing Ray cluster in localhost.
 mode: parallel
 parallel:
-  ray_cluster_addr: null
+  ...
   num_gpus: null
   num_cpus: null
-  ...
   ...
 ```
-```shell
-ray start --head [OPTIONS]
-```
-👀 **NOTE:** You need to keep `num_cpus: null` and `num_gpus: null` in your config file for connecting to a existing `Ray` cluster.
+
diff --git a/src/client/fedfed.py b/src/client/fedfed.py
new file mode 100644
index 0000000..bdc7d39
--- /dev/null
+++ b/src/client/fedfed.py
@@ -0,0 +1,217 @@
+import gc
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from src.client.fedavg import FedAvgClient
+from src.utils.tools import trainable_params
+
+
+class FedFedClient(FedAvgClient):
+    def __init__(self, VAE_cls, VAE_optimizer_cls, **commons):
+        super().__init__(**commons)
+        self.VAE: torch.nn.Module = VAE_cls(self.args).to(self.device)
+        self.VAE_optimizer: torch.optim.Optimizer = VAE_optimizer_cls(
+            params=trainable_params(self.VAE)
+        )
+        self.offset_ori_dataset = len(self.dataset)
+        self.distilling = True
+
+    def set_parameters(self, package: dict[str, Any]):
+        self.distilling = package.get("distilling", False)
+        super().set_parameters(package)
+        if self.distilling:
+            self.VAE.load_state_dict(package["VAE_regular_params"], strict=False)
+            self.VAE.load_state_dict(package["VAE_personal_params"], strict=False)
+            self.VAE_optimizer.load_state_dict(package["VAE_optimizer_state"])
+
+    def load_data_indices(self):
+        if self.distilling:
+            self.trainset.indices = self.data_indices[self.client_id]["train"]
+        else:
+            idxs_shared = np.random.choice(
+                len(self.dataset) - self.offset_ori_dataset,
+                len(self.data_indices[self.client_id]["train"]),
+                replace=False,
+            )
+            # shared samples are appended after the original dataset, so their
+            # global indices start at offset_ori_dataset
+            self.trainset.indices = np.concatenate(
+                [
+                    self.data_indices[self.client_id]["train"],
+                    idxs_shared + self.offset_ori_dataset,
+                ]
+            )
+        self.valset.indices = self.data_indices[self.client_id]["val"]
+        self.testset.indices = self.data_indices[self.client_id]["test"]
+
+    def train_VAE(self, package: dict[str, Any]):
+        self.set_parameters(package)
+        self.model.train()
+        self.dataset.train()
+        for _ in range(self.args.fedfed.VAE_train_local_epoch):
+            # first pass: warm up the classifier on mixup-augmented batches
+            self.VAE.eval()
+            for x, y in self.trainloader:
+                if len(y) <= 1:
+                    continue
+                x, y = x.to(self.device), y.to(self.device)
+                x_mixed, y_ori, y_rand, lamda = mixup_data(
+                    x, y, self.args.fedfed.VAE_alpha
+                )
+
+                logits = self.model(x_mixed)
+                loss_classifier = lamda * F.cross_entropy(logits, y_ori) + (
+                    1 - lamda
+                ) * F.cross_entropy(logits, y_rand)
+                self.optimizer.zero_grad()
+                loss_classifier.backward()
+                self.optimizer.step()
+
+            # second pass: train the VAE alone with its reconstruction + KL objective
+            self.VAE.train()
+            for x, y in self.trainloader:
+                if len(y) <= 1:
+                    continue
+                x, y = x.to(self.device), y.to(self.device)
+
+                robust, mu, logvar = self.VAE(x)
+
+                loss_VAE = (self.args.fedfed.VAE_re * F.mse_loss(robust, x)) + (
+                    self.args.fedfed.VAE_kl
+                    * (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()))
+                    / (self.args.fedfed.VAE_batch_size * 3 * self.VAE.feature_length)
+                )
+                self.VAE_optimizer.zero_grad()
+                loss_VAE.backward()
+                self.VAE_optimizer.step()
+
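+            # third pass: joint training -- the classifier must handle the
+            # noise-protected sensitive features as well as the raw inputs,
+            # while the VAE keeps its reconstruction/KL objective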
+            for x, y in self.trainloader:
+                if len(y) <= 1:
+                    continue
+                x, y = x.to(self.device), y.to(self.device)
+                batch_size = x.shape[0]
+                robust, mu, logvar = self.VAE(x)
+                sensitive = x - robust
+                sensitive_protected1 = self.VAE.add_noise(
+                    sensitive,
+                    self.args.fedfed.VAE_noise_mean,
+                    self.args.fedfed.VAE_noise_std1,
+                )
+                sensitive_protected2 = self.VAE.add_noise(
+                    sensitive,
+                    self.args.fedfed.VAE_noise_mean,
+                    self.args.fedfed.VAE_noise_std2,
+                )
+                data = torch.cat([sensitive_protected1, sensitive_protected2, x])
+                logits = self.model(data)
+
+                loss_features_sensitive_protected = F.cross_entropy(
+                    logits[: batch_size * 2], y.repeat(2)
+                )
+                loss_x = F.cross_entropy(logits[batch_size * 2 :], y)
+                loss_mse = F.mse_loss(robust, x)
+                loss_kl = (
+                    -0.5
+                    * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+                    / (self.args.fedfed.VAE_batch_size * 3 * self.VAE.feature_length)
+                )
+
+                loss = (
+                    self.args.fedfed.VAE_re * loss_mse
+                    + self.args.fedfed.VAE_kl * loss_kl
+                    + self.args.fedfed.VAE_ce * loss_features_sensitive_protected
+                    + self.args.fedfed.VAE_x_ce * loss_x
+                )
+
+                self.VAE_optimizer.zero_grad()
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.VAE_optimizer.step()
+                self.optimizer.step()
+
+        VAE_regular_params, VAE_personal_params = {}, {}
+        for key, param in self.VAE.state_dict(keep_vars=True).items():
+            if param.requires_grad:
+                VAE_regular_params[key] = param.detach().cpu().clone()
+            else:
+                VAE_personal_params[key] = param.detach().cpu().clone()
+
+        _, regular_keys = trainable_params(self.model, requires_name=True)
+        model_params = self.model.state_dict(keep_vars=True)
+        return dict(
+            weight=len(self.trainset),
+            regular_model_params={
+                key: model_params[key].detach().clone().cpu() for key in regular_keys
+            },
+            personal_model_params={
+                key: param.detach().clone().cpu()
+                for key, param in model_params.items()
+                if (not param.requires_grad) or (key in self.personal_params_name)
+            },
+            VAE_regular_params=VAE_regular_params,
+            VAE_personal_params=VAE_personal_params,
+            optimizer_state=deepcopy(self.optimizer.state_dict()),
+            VAE_optimizer_state=deepcopy(self.VAE_optimizer.state_dict()),
+        )
+
+    @torch.no_grad()
+    def generate_shared_data(self, package: dict[str, Any]):
+        self.set_parameters(package)
+        self.dataset.eval()
+        self.VAE.eval()
+
+        data1 = []
+        data2 = []
+        targets = []
+        for x, y in self.trainloader:
+            x, y = x.to(self.device), y.to(self.device)
+
+            robust, _, _ = self.VAE(x)
+            sensitive = x - robust
+            data1.append(
+                self.VAE.add_noise(
+                    sensitive,
+                    self.args.fedfed.VAE_noise_mean,
+                    self.args.fedfed.VAE_noise_std1,
+                )
+            )
+            data2.append(
+                self.VAE.add_noise(
+                    sensitive,
+                    self.args.fedfed.VAE_noise_mean,
+                    self.args.fedfed.VAE_noise_std2,
+                )
+            )
+            targets.append(y)
+
+        data1 = torch.cat(data1).float().cpu()
+        data2 = torch.cat(data2).float().cpu()
+        targets = torch.cat(targets).long().cpu()
+
+        return dict(data1=data1, data2=data2, targets=targets)
+
+    def accept_global_shared_data(self, package: dict[str, Any]):
+        # avoid loading multiple times: this only triggers once per worker
+        # (worker != client)
+        if self.distilling:
+            self.distilling = False
+
+            # regular training doesn't need the VAE anymore
+            del self.VAE, self.VAE_optimizer
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            self.dataset.data = torch.cat(
+                [self.dataset.data, package["data1"], package["data2"]]
+            )
+            self.dataset.targets = torch.cat(
+                [self.dataset.targets, package["targets"], package["targets"]]
+            )
+
+
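+# standard mixup (Zhang et al., 2018): convex-combine the batch with a shuffled
+# copy of itself; the caller weighs the two cross-entropy losses by lamda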
+def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float):
+    if alpha > 0:
+        lamda = np.random.beta(alpha, alpha)
+    else:
+        lamda = 1.0
+
+    shfl_idxs = np.random.permutation(x.shape[0])
+    x_mixed = lamda * x + (1 - lamda) * x[shfl_idxs, :]
+    return x_mixed, y, y[shfl_idxs], lamda
diff --git a/src/client/fedprox.py b/src/client/fedprox.py
index a793f2c..4885111 100644
--- a/src/client/fedprox.py
+++ b/src/client/fedprox.py
@@ -23,6 +23,6 @@ def fit(self):
                 for w, w_t in zip(trainable_params(self.model), global_params):
                     w.grad.data += self.args.fedprox.mu * (w.data - w_t.data)
                 self.optimizer.step()
-
+
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()
diff --git a/src/server/fedavg.py b/src/server/fedavg.py
index e908c0f..0eb453a 100644
--- a/src/server/fedavg.py
+++ b/src/server/fedavg.py
@@ -117,8 +117,13 @@ def __init__(
         self.clients_optimizer_state = {i: {} for i in range(self.client_num)}
         self.clients_lr_scheduler_state = {i: {} for i in range(self.client_num)}
 
-        model_params_file_path = str((FLBENCH_ROOT / self.args.common.external_model_params_file).absolute())
-        if os.path.isfile(model_params_file_path) and model_params_file_path.find(".pt") != -1:
+        model_params_file_path = str(
+            (FLBENCH_ROOT / self.args.common.external_model_params_file).absolute()
+        )
+        if (
+            os.path.isfile(model_params_file_path)
+            and model_params_file_path.find(".pt") != -1
+        ):
             self.global_model_params = torch.load(
                 model_params_file_path, map_location="cpu"
             )
@@ -213,8 +218,7 @@ def __init__(
             self.tensorboard = SummaryWriter(log_dir=self.output_dir)
             self.tensorboard.add_text(
-                "Experimental Arguments",
-                f"```\n{self.args}\n```",
+                "Experimental Arguments", f"```\n{self.args}\n```"
             )
 
         # init trainer
         self.trainer: FLbenchTrainer
@@ -683,7 +687,9 @@ def run(self):
         if self.args.common.visible == "tensorboard":
             for epoch, results in all_test_results.items():
                 self.tensorboard.add_text(
-                    "Test Results", text_string=f"```\n{results}\n```", global_step=epoch
+                    "Test Results",
+                    text_string=f"```\n{results}\n```",
+                    global_step=epoch,
                 )
 
         if self.args.common.check_convergence:
diff --git a/src/server/fedfed.py b/src/server/fedfed.py
new file mode 100644
index 0000000..6568a7f
--- /dev/null
+++ b/src/server/fedfed.py
@@ -0,0 +1,309 @@
+import math
+import random
+from argparse import ArgumentParser, Namespace
+from collections import OrderedDict
+from copy import deepcopy
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from rich.progress import track
+
+from src.server.fedavg import FedAvgServer
+from src.client.fedfed import FedFedClient
+from src.utils.constants import DATA_SHAPE
+from src.utils.tools import NestedNamespace, trainable_params
+
+
+def get_fedfed_args(arg_list=None) -> Namespace:
+    parser = ArgumentParser()
+    parser.add_argument("--VAE_train_global_epoch", type=int, default=15)
+    parser.add_argument("--VAE_train_local_epoch", type=int, default=1)
+    parser.add_argument("--VAE_lr", type=float, default=1e-3)
+    parser.add_argument("--VAE_weight_decay", type=float, default=1e-6)
+    parser.add_argument("--VAE_alpha", type=float, default=2.0)
+    parser.add_argument("--VAE_noise_mean", type=float, default=0)
+    parser.add_argument("--VAE_noise_std1", type=float, default=0.15)
+    parser.add_argument("--VAE_noise_std2", type=float, default=0.25)
+    parser.add_argument("--VAE_re", type=float, default=5.0)
+    parser.add_argument("--VAE_x_ce", type=float, default=0.4)
+    parser.add_argument("--VAE_kl", type=float, default=0.005)
+    parser.add_argument("--VAE_ce", type=float, default=2.0)
+    parser.add_argument("--VAE_batch_size", type=int, default=64)
+    parser.add_argument("--VAE_block_depth", type=int, default=32)
+    parser.add_argument(
+        "--VAE_noise_type",
+        type=str,
+        choices=["laplace", "gaussian"],
+        default="gaussian",
+    )
+    return parser.parse_args(arg_list)
+
+
+class FedFedServer(FedAvgServer):
+    def __init__(
+        self,
+        args: NestedNamespace,
+        algo: str = "FedFed",
+        unique_model=False,
+        use_fedavg_client_cls=False,
+        return_diff=False,
+    ):
+        super().__init__(args, algo, unique_model, use_fedavg_client_cls, return_diff)
+        dummy_VAE_model = VAE(self.args)
+        VAE_optimizer_cls = partial(
+            torch.optim.AdamW,
+            lr=self.args.fedfed.VAE_lr,
+            weight_decay=self.args.fedfed.VAE_weight_decay,
+        )
+        dummy_VAE_optimizer = VAE_optimizer_cls(
+            params=trainable_params(dummy_VAE_model)
+        )
+        self.init_trainer(
+            FedFedClient, VAE_cls=VAE, VAE_optimizer_cls=VAE_optimizer_cls
+        )
+        params, keys = trainable_params(
+            dummy_VAE_model, detach=True, requires_name=True
+        )
+        init_VAE_personal_params = OrderedDict(
+            (key, param)
+            for key, param in dummy_VAE_model.state_dict(keep_vars=True).items()
+            if not param.requires_grad
+        )
+        self.global_VAE_params = OrderedDict(zip(keys, params))
+        self.client_VAE_personal_params = {
+            i: deepcopy(init_VAE_personal_params) for i in self.train_clients
+        }
+        self.client_VAE_optimizer_states = {
+            i: deepcopy(dummy_VAE_optimizer.state_dict()) for i in self.train_clients
+        }
+        del dummy_VAE_model, dummy_VAE_optimizer
+
+        self.feature_distill()
+
+    def feature_distill(self):
+        """Train VAE, generate shared data, distribute shared data"""
+
+        def _package_VAE(client_id: int):
+            server_package = self.package(client_id)
+            server_package["distilling"] = True
+            server_package["VAE_regular_params"] = self.global_VAE_params
+            server_package["VAE_personal_params"] = self.client_VAE_personal_params.get(
+                client_id
+            )
+            server_package["VAE_optimizer_state"] = (
+                self.client_VAE_optimizer_states.get(client_id)
+            )
+            return server_package
+
+        num_join = max(1, int(self.args.common.join_ratio * len(self.train_clients)))
+        for _ in track(
+            range(self.args.fedfed.VAE_train_global_epoch),
+            description="[green bold]Training VAE...",
+        ):
+            selected_clients = random.sample(self.train_clients, num_join)
+            client_packages = self.trainer.exec(
+                func_name="train_VAE",
+                clients=selected_clients,
+                package_func=_package_VAE,
+            )
+            for client_id, package in client_packages.items():
+                self.clients_personal_model_params[client_id].update(
+                    package["personal_model_params"]
+                )
+                self.clients_optimizer_state[client_id].update(
+                    package["optimizer_state"]
+                )
+                self.client_VAE_personal_params[client_id] = package[
+                    "VAE_personal_params"
+                ]
+                self.client_VAE_optimizer_states[client_id] = package[
+                    "VAE_optimizer_state"
+                ]
+            super().aggregate(client_packages)
+
+            # aggregate client VAEs (weighted by local trainset size)
+            weights = torch.tensor(
+                [package["weight"] for package in client_packages.values()],
+                dtype=torch.float,
+            )
+            weights /= weights.sum()
+            client_VAE_regular_params = [
+                list(package["VAE_regular_params"].values())
+                for package in client_packages.values()
+            ]
+            for global_param, zipped_new_params in zip(
+                self.global_VAE_params.values(), zip(*client_VAE_regular_params)
+            ):
+                global_param.data = torch.sum(
+                    torch.stack(zipped_new_params, dim=-1) * weights, dim=-1
+                )
+
+        # gather clients' performance-sensitive data
+        client_packages = self.trainer.exec(
+            func_name="generate_shared_data",
+            clients=self.train_clients,
+            package_func=_package_VAE,
+        )
+        data1, data2, targets = [], [], []
+        for package in client_packages.values():
+            data1.append(package["data1"])
+            data2.append(package["data2"])
+            targets.append(package["targets"])
+
+        global_shared_data1 = torch.cat(data1)
+        global_shared_data2 = torch.cat(data2)
+        global_shared_targets = torch.cat(targets)
+
+        # distribute the global shared data
+        def _package_distribute_data(client_id: int):
+            nonlocal global_shared_data1, global_shared_data2, global_shared_targets
+            return dict(
+                client_id=client_id,
+                data1=global_shared_data1,
+                data2=global_shared_data2,
+                targets=global_shared_targets,
+            )
+
+        self.trainer.exec(
+            func_name="accept_global_shared_data",
+            clients=self.train_clients,
+            package_func=_package_distribute_data,
+        )
+
+
+# Modified from the official code
+class VAE(nn.Module):
+    def __init__(self, args):
+        super(VAE, self).__init__()
+
+        class ResidualBlock(nn.Module):
+            def __init__(self, in_channels, out_channels=None):
+                super(ResidualBlock, self).__init__()
+                if out_channels is None:
+                    out_channels = in_channels
+                layers = [
+                    nn.LeakyReLU(),
+                    nn.Conv2d(
+                        in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                    ),
+                    nn.BatchNorm2d(out_channels),
+                    nn.LeakyReLU(),
+                    nn.Conv2d(
+                        out_channels, out_channels, kernel_size=1, stride=1, padding=0
+                    ),
+                ]
+                self.block = nn.Sequential(*layers)
+
+            def forward(self, x):
+                return x + self.block(x)
+
+        self.args = deepcopy(args)
+        img_depth = DATA_SHAPE[self.args.common.dataset][0]
+        img_shape = DATA_SHAPE[self.args.common.dataset][:-1]
+
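+        # trace a dummy batch through the encoder so feature_length adapts to
+        # whatever input shape DATA_SHAPE reports for the chosen dataset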
+        dummy_input = torch.randn(2, *DATA_SHAPE[self.args.common.dataset])
+        self.encoder = nn.Sequential(
+            nn.Conv2d(
+                img_depth,
+                self.args.fedfed.VAE_block_depth // 2,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(self.args.fedfed.VAE_block_depth // 2),
+            nn.ReLU(),
+            nn.Conv2d(
+                self.args.fedfed.VAE_block_depth // 2,
+                self.args.fedfed.VAE_block_depth,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(self.args.fedfed.VAE_block_depth),
+            nn.ReLU(),
+            ResidualBlock(self.args.fedfed.VAE_block_depth),
+            nn.BatchNorm2d(self.args.fedfed.VAE_block_depth),
+            ResidualBlock(self.args.fedfed.VAE_block_depth),
+        )
+        with torch.no_grad():
+            dummy_feature = self.encoder(dummy_input)
+        self.feature_length = dummy_feature.flatten(start_dim=1).shape[-1]
+        self.feature_side = int(
+            math.sqrt(self.feature_length // self.args.fedfed.VAE_block_depth)
+        )
+
+        self.decoder = nn.Sequential(
+            ResidualBlock(self.args.fedfed.VAE_block_depth),
+            nn.BatchNorm2d(self.args.fedfed.VAE_block_depth),
+            ResidualBlock(self.args.fedfed.VAE_block_depth),
+            nn.BatchNorm2d(self.args.fedfed.VAE_block_depth),
+            nn.ConvTranspose2d(
+                self.args.fedfed.VAE_block_depth,
+                self.args.fedfed.VAE_block_depth // 2,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(self.args.fedfed.VAE_block_depth // 2),
+            nn.LeakyReLU(),
+            # NOTE: the official code really does apply Tanh() right after
+            # LeakyReLU(), which is odd, but it is kept here for fidelity.
+            nn.Tanh(),
+            # BTW, FedFed's beta-VAE code differs hugely from other reproductions,
+            # such as https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
+            nn.ConvTranspose2d(
+                self.args.fedfed.VAE_block_depth // 2,
+                img_depth,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(img_depth),
+            nn.Sigmoid(),
+        )
+
+        self.fc_mu = nn.Linear(self.feature_length, self.feature_length)
+        self.fc_logvar = nn.Linear(self.feature_length, self.feature_length)
+        self.decoder_input = nn.Linear(self.feature_length, self.feature_length)
+
+    def add_noise(self, data: torch.Tensor, mean, std):
+        if self.args.fedfed.VAE_noise_type == "gaussian":
+            noise = torch.normal(
+                mean=mean, std=std, size=data.shape, device=data.device
+            )
+        elif self.args.fedfed.VAE_noise_type == "laplace":
+            noise = torch.tensor(
+                np.random.laplace(loc=mean, scale=std, size=data.shape),
+                device=data.device,
+            )
+        return data + noise
+
+    def encode(self, x):
+        x = self.encoder(x).flatten(start_dim=1, end_dim=-1)
+        return self.fc_mu(x), self.fc_logvar(x)
+
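+    # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I), which
+    # keeps sampling differentiable w.r.t. mu and logvar during training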
+    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
+        if self.training:
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std, device=std.device)
+            return eps * std + mu
+        else:
+            return mu
+
+    def decode(self, z):
+        result = self.decoder_input(z)
+        result = result.view(
+            -1, self.args.fedfed.VAE_block_depth, self.feature_side, self.feature_side
+        )
+        return self.decoder(result)
+
+    def forward(self, x):
+        mu, logvar = self.encode(x)
+        z = self.reparameterize(mu, logvar)
+        robust = self.decode(z)
+        return robust, mu, logvar
diff --git a/src/server/local.py b/src/server/local.py
index 8a4b6b1..13c09c3 100644
--- a/src/server/local.py
+++ b/src/server/local.py
@@ -16,5 +16,9 @@ def __init__(
     def train_one_round(self):
         client_packages = self.trainer.train()
         for client_id, package in client_packages.items():
-            self.clients_personal_model_params[client_id].update(package["regular_model_params"])
-            self.clients_personal_model_params[client_id].update(package["personal_model_params"])
\ No newline at end of file
+            self.clients_personal_model_params[client_id].update(
+                package["regular_model_params"]
+            )
+            self.clients_personal_model_params[client_id].update(
+                package["personal_model_params"]
+            )
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index d7de78f..e1608b6 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -31,11 +31,10 @@ def _calculate(self, metric, **kwargs):
 
     @property
     def loss(self):
-        try:
-            loss = self._loss / len(self._targets)
-        except ZeroDivisionError:
+        if len(self._targets) > 0:
+            return self._loss / len(self._targets)
+        else:
             return 0
-        return loss
 
     @property
     def macro_precision(self):
diff --git a/src/utils/models.py b/src/utils/models.py
index 2e67f0a..92a7552 100644
--- a/src/utils/models.py
+++ b/src/utils/models.py
@@ -1,7 +1,6 @@
-import json
 from functools import partial
 from collections import OrderedDict
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -9,7 +8,7 @@
 import torchvision.models as models
 from torch import Tensor
 
-from .constants import NUM_CLASSES, INPUT_CHANNELS, FLBENCH_ROOT
+from src.utils.constants import DATA_SHAPE, NUM_CLASSES, INPUT_CHANNELS
 
 
 class DecoupledModel(nn.Module):
@@ -86,25 +85,25 @@ def get_all_features(self, x: Tensor) -> Optional[list[Tensor]]:
 
 # CNN used in FedAvg
 class FedAvgCNN(DecoupledModel):
+    feature_length = {
+        "mnist": 1024,
+        "medmnistS": 1024,
+        "medmnistC": 1024,
+        "medmnistA": 1024,
+        "covid19": 196736,
+        "fmnist": 1024,
+        "emnist": 1024,
+        "femnist": 1,
+        "cifar10": 1600,
+        "cinic10": 1600,
+        "cifar100": 1600,
+        "tiny_imagenet": 3200,
+        "celeba": 133824,
+        "svhn": 1600,
+        "usps": 800,
+    }
+
     def __init__(self, dataset: str):
         super(FedAvgCNN, self).__init__()
-        features_length = {
-            "mnist": 1024,
-            "medmnistS": 1024,
-            "medmnistC": 1024,
-            "medmnistA": 1024,
-            "covid19": 196736,
-            "fmnist": 1024,
-            "emnist": 1024,
-            "femnist": 1,
-            "cifar10": 1600,
-            "cinic10": 1600,
-            "cifar100": 1600,
-            "tiny_imagenet": 3200,
-            "celeba": 133824,
-            "svhn": 1600,
-            "usps": 800,
-        }
         self.base = nn.Sequential(
             OrderedDict(
                 conv1=nn.Conv2d(INPUT_CHANNELS[dataset], 32, 5),
@@ -114,7 +113,7 @@ def __init__(self, dataset: str):
                 activation2=nn.ReLU(),
                 pool2=nn.MaxPool2d(2),
                 flatten=nn.Flatten(),
-                fc1=nn.Linear(features_length[dataset], 512),
+                fc1=nn.Linear(self.feature_length[dataset], 512),
             )
         )
         self.classifier = nn.Linear(512, NUM_CLASSES[dataset])
@@ -124,25 +123,25 @@ def forward(self, x):
 
 
 class LeNet5(DecoupledModel):
+    feature_length = {
+        "mnist": 256,
+        "medmnistS": 256,
+        "medmnistC": 256,
+        "medmnistA": 256,
+        "covid19": 49184,
+        "fmnist": 256,
+        "emnist": 256,
+        "femnist": 256,
+        "cifar10": 400,
+        "cinic10": 400,
+        "svhn": 400,
+        "cifar100": 400,
+        "celeba": 33456,
+        "usps": 200,
+        "tiny_imagenet": 2704,
+    }
+
     def __init__(self, dataset: str) -> None:
         super(LeNet5, self).__init__()
-        feature_length = {
-            "mnist": 256,
-            "medmnistS": 256,
-            "medmnistC": 256,
-            "medmnistA": 256,
-            "covid19": 49184,
-            "fmnist": 256,
-            "emnist": 256,
-            "femnist": 256,
-            "cifar10": 400,
-            "cinic10": 400,
-            "svhn": 400,
-            "cifar100": 400,
-            "celeba": 33456,
-            "usps": 200,
-            "tiny_imagenet": 2704,
-        }
         self.base = nn.Sequential(
             OrderedDict(
                 conv1=nn.Conv2d(INPUT_CHANNELS[dataset], 6, 5),
@@ -154,7 +153,7 @@ def __init__(self, dataset: str) -> None:
                 activation2=nn.ReLU(),
                 pool2=nn.MaxPool2d(2),
                 flatten=nn.Flatten(),
-                fc1=nn.Linear(feature_length[dataset], 120),
+                fc1=nn.Linear(self.feature_length[dataset], 120),
                 activation3=nn.ReLU(),
                 fc2=nn.Linear(120, 84),
             )
@@ -167,34 +166,25 @@ def forward(self, x):
 
 
 class TwoNN(DecoupledModel):
+    feature_length = {
+        "mnist": 784,
+        "medmnistS": 784,
+        "medmnistC": 784,
+        "medmnistA": 784,
+        "fmnist": 784,
+        "emnist": 784,
+        "femnist": 784,
+        "cifar10": 3072,
+        "cinic10": 3072,
+        "svhn": 3072,
+        "cifar100": 3072,
+        "usps": 1536,
+        "synthetic": DATA_SHAPE["synthetic"],
+    }
+
     def __init__(self, dataset):
         super(TwoNN, self).__init__()
-
-        def get_synthetic_dimension():
-            try:
-                with open(FLBENCH_ROOT / "data" / "synthetic" / "args.json", "r") as f:
-                    metadata = json.load(f)
-                return metadata["dimension"]
-            except:
-                return 0
-
-        features_length = {
-            "mnist": 784,
-            "medmnistS": 784,
-            "medmnistC": 784,
-            "medmnistA": 784,
-            "fmnist": 784,
-            "emnist": 784,
-            "femnist": 784,
-            "cifar10": 3072,
-            "cinic10": 3072,
-            "svhn": 3072,
-            "cifar100": 3072,
-            "usps": 1536,
-            "synthetic": get_synthetic_dimension(),
-        }
         self.base = nn.Sequential(
-            nn.Linear(features_length[dataset], 200),
+            nn.Linear(self.feature_length[dataset], 200),
             nn.ReLU(inplace=True),
             nn.Linear(200, 200),
             nn.ReLU(inplace=True),
diff --git a/src/utils/trainer.py b/src/utils/trainer.py
index a9f892a..de38a16 100644
--- a/src/utils/trainer.py
+++ b/src/utils/trainer.py
@@ -1,4 +1,5 @@
 from collections import OrderedDict, deque
+from typing import Any, Callable
 
 import ray
 import ray.actor
@@ -138,18 +139,28 @@ def _parallel_test(self, clients: list[int], results: dict):
                 results[stage][split].update(metrics[stage][split])
 
     def _serial_exec(
-        self, func_name: str, clients: list[int], package_func_name: str = "package"
+        self,
+        func_name: str,
+        clients: list[int],
+        package_func: Callable[[int], dict[str, Any]] = None,
     ):
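+        # fall back to the server's default packaging; algorithm servers
+        # (e.g. FedFed) pass a custom package_func to ship extra payloads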
+        if package_func is None:
+            package_func = getattr(self.server, "package")
         clients_package = OrderedDict()
         for client_id in clients:
-            server_package = getattr(self.server, package_func_name)(client_id)
+            server_package = package_func(client_id)
             package = getattr(self.worker, func_name)(server_package)
             clients_package[client_id] = package
         return clients_package
 
     def _parallel_exec(
-        self, func_name: str, clients: list[int], package_func_name: str = "package"
+        self,
+        func_name: str,
+        clients: list[int],
+        package_func: Callable[[int], dict[str, Any]] = None,
     ):
+        if package_func is None:
+            package_func = getattr(self.server, "package")
         clients_package = OrderedDict()
         i = 0
         futures = []
@@ -157,9 +168,7 @@ def _parallel_exec(
         map = {}  # {future: (client_id, worker_id)}
         while i < len(clients) or len(futures) > 0:
             while i < len(clients) and len(idle_workers) > 0:
-                server_package = ray.put(
-                    getattr(self.server, package_func_name)(clients[i])
-                )
+                server_package = ray.put(package_func(clients[i]))
                 worker_id = idle_workers.popleft()
                 future = getattr(self.workers[worker_id], func_name).remote(
                     server_package