From e28f200d5cdacfe92d5c9d006679b2d38b23e5e6 Mon Sep 17 00:00:00 2001
From: KarhouTam
Date: Tue, 2 Apr 2024 16:36:50 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=8C=88=20style(text):=20Rename=20some=20v?=
 =?UTF-8?q?ariables?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Rename test_gap to test_interval
Rename test_flag in client classes to testing
Format some code with the new version of black
---
 src/client/fedavg.py  |  8 +++++---
 src/server/adcol.py   |  4 ++--
 src/server/fedap.py   |  2 +-
 src/server/fedavg.py  | 22 ++++++++++------------
 src/server/metafed.py |  2 +-
 src/server/pfedsim.py |  2 +-
 6 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/client/fedavg.py b/src/client/fedavg.py
index b44fb7d..6bcc57f 100644
--- a/src/client/fedavg.py
+++ b/src/client/fedavg.py
@@ -38,7 +38,8 @@ def __init__(
         except:
             raise FileNotFoundError(f"Please partition {args.dataset} first.")
 
-        self.data_indices: List[List[int]] = partition["data_indices"]
+        # [0: {"train": [...], "val": [...], "test": [...]}, ...]
+        self.data_indices: List[Dict[str, List[int]]] = partition["data_indices"]
 
         # --------- you can define your custom data preprocessing strategy here ------------
         test_data_transform = transforms.Compose(
@@ -311,7 +312,7 @@ def test(
         Returns:
             Dict[str, Dict[str, Metrics]]: the evalutaion metrics stats.
         """
-        self.test_flag = True
+        self.testing = True
         self.client_id = client_id
         self.load_dataset()
         self.set_parameters(new_parameters)
@@ -327,7 +328,8 @@ def test(
             self.finetune()
             results["after"] = self.evaluate()
             self.model.load_state_dict(frz_params_dict)
-        self.test_flag = False
+
+        self.testing = False
         return results
 
     def finetune(self):
diff --git a/src/server/adcol.py b/src/server/adcol.py
index 8be53c5..2ef7ce8 100644
--- a/src/server/adcol.py
+++ b/src/server/adcol.py
@@ -112,12 +112,12 @@ def train_one_round(self):
 
     def train_and_test_discriminator(self):
         self.generate_client_index()
-        if (self.current_epoch + 1) % self.args.test_gap == 0:
+        if (self.current_epoch + 1) % self.args.test_interval == 0:
            acc_before = self.test_discriminator()
 
         self.train_discriminator()
 
-        if (self.current_epoch + 1) % self.args.test_gap == 0:
+        if (self.current_epoch + 1) % self.args.test_interval == 0:
             acc_after = self.test_discriminator()
             if (self.current_epoch + 1) % self.args.verbose_gap == 0:
                 self.logger.log(
diff --git a/src/server/fedap.py b/src/server/fedap.py
index 5374799..3c384ce 100644
--- a/src/server/fedap.py
+++ b/src/server/fedap.py
@@ -139,7 +139,7 @@ def train(self):
             if (E + 1) % self.args.verbose_gap == 0:
                 self.logger.log(" " * 30, f"TRAINING EPOCH: {E + 1}", " " * 30)
 
-            if (E + 1) % self.args.test_gap == 0:
+            if (E + 1) % self.args.test_interval == 0:
                 self.test()
 
             self.selected_clients = self.client_sample_stream[E]
diff --git a/src/server/fedavg.py b/src/server/fedavg.py
index e197f01..dc15f21 100644
--- a/src/server/fedavg.py
+++ b/src/server/fedavg.py
@@ -47,7 +47,7 @@ def get_fedavg_argparser() -> ArgumentParser:
     parser.add_argument("-ge", "--global_epoch", type=int, default=100)
     parser.add_argument("-le", "--local_epoch", type=int, default=5)
     parser.add_argument("-fe", "--finetune_epoch", type=int, default=0)
-    parser.add_argument("-tg", "--test_gap", type=int, default=100)
+    parser.add_argument("-ti", "--test_interval", type=int, default=100)
     parser.add_argument("--eval_test", type=int, default=1)
     parser.add_argument("--eval_val", type=int, default=0)
     parser.add_argument("--eval_train", type=int, default=0)
@@ -175,7 +175,7 @@ def __init__(
         self.epoch_test = [
             epoch
             for epoch in range(0, self.args.global_epoch)
-            if (epoch + 1) % self.args.test_gap == 0
+            if (epoch + 1) % self.args.test_interval == 0
         ]
         # For controlling behaviors of some specific methods while testing (not used by all methods)
         self.test_flag = False
@@ -237,7 +237,7 @@ def train(self):
             if (E + 1) % self.args.verbose_gap == 0:
                 self.logger.log("-" * 26, f"TRAINING EPOCH: {E + 1}", "-" * 26)
 
-            if (E + 1) % self.args.test_gap == 0:
+            if (E + 1) % self.args.test_interval == 0:
                 self.test()
 
             self.selected_clients = self.client_sample_stream[E]
@@ -259,15 +259,13 @@ def train_one_round(self):
         weight_cache = []
         for client_id in self.selected_clients:
             client_local_params = self.generate_client_params(client_id)
-            (
-                delta,
-                weight,
-                self.client_metrics[client_id][self.current_epoch],
-            ) = self.trainer.train(
-                client_id=client_id,
-                local_epoch=self.clients_local_epoch[client_id],
-                new_parameters=client_local_params,
-                verbose=((self.current_epoch + 1) % self.args.verbose_gap) == 0,
+            (delta, weight, self.client_metrics[client_id][self.current_epoch]) = (
+                self.trainer.train(
+                    client_id=client_id,
+                    local_epoch=self.clients_local_epoch[client_id],
+                    new_parameters=client_local_params,
+                    verbose=((self.current_epoch + 1) % self.args.verbose_gap) == 0,
+                )
             )
 
             delta_cache.append(delta)
diff --git a/src/server/metafed.py b/src/server/metafed.py
index 36c944c..cf165ff 100644
--- a/src/server/metafed.py
+++ b/src/server/metafed.py
@@ -64,7 +64,7 @@ def train(self):
             if (E + 1) % self.args.verbose_gap == 0:
                 self.logger.log("-" * 26, f"TRAINING EPOCH: {E + 1}", "-" * 26)
 
-            if (E + 1) % self.args.test_gap == 0:
+            if (E + 1) % self.args.test_interval == 0:
                 self.test()
 
             client_params_cache = []
diff --git a/src/server/pfedsim.py b/src/server/pfedsim.py
index d774375..6395966 100644
--- a/src/server/pfedsim.py
+++ b/src/server/pfedsim.py
@@ -66,7 +66,7 @@ def train(self):
             if (E + 1) % self.args.verbose_gap == 0:
                 self.logger.log(" " * 30, f"TRAINING EPOCH: {E + 1}", " " * 30)
 
-            if (E + 1) % self.args.test_gap == 0:
+            if (E + 1) % self.args.test_interval == 0:
                 self.test()
 
             self.selected_clients = self.client_sample_stream[E]
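
Note: every server-side hunk replaces self.args.test_gap with self.args.test_interval
inside the same gating expression, (E + 1) % self.args.test_interval == 0, so
evaluation still fires once every test_interval epochs. Below is a minimal,
self-contained sketch of that pattern; the argparse flags mirror the ones in the
patch, but the training loop and the evaluate() helper are placeholders rather
than FL-bench code.

    from argparse import ArgumentParser, Namespace


    def get_argparser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("-ge", "--global_epoch", type=int, default=100)
        # renamed flag: -ti / --test_interval (formerly -tg / --test_gap)
        parser.add_argument("-ti", "--test_interval", type=int, default=100)
        return parser


    def evaluate(epoch: int) -> None:
        # stand-in for the server's test() routine; just reports when it fires
        print(f"evaluating after epoch {epoch + 1}")


    def train(args: Namespace) -> None:
        for E in range(args.global_epoch):
            # the patched condition: test only every test_interval-th epoch
            if (E + 1) % args.test_interval == 0:
                evaluate(E)
            # ... one round of client sampling, local training and aggregation ...


    if __name__ == "__main__":
        train(get_argparser().parse_args())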
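Note: the first hunk in src/client/fedavg.py also documents the shape that
partition["data_indices"] is now annotated with: a list with one entry per
client, each entry a dict holding "train"/"val"/"test" index lists. The toy
example below builds and reads a structure of that shape; make_dummy_partition
and its split ratios are illustrative only and are not part of the repository.

    import random
    from typing import Dict, List


    def make_dummy_partition(
        num_clients: int, samples_per_client: int
    ) -> List[Dict[str, List[int]]]:
        """Build a toy partition shaped like the annotated data_indices."""
        partition: List[Dict[str, List[int]]] = []
        start = 0
        for _ in range(num_clients):
            indices = list(range(start, start + samples_per_client))
            random.shuffle(indices)
            n_train = int(0.8 * len(indices))
            n_val = int(0.1 * len(indices))
            partition.append(
                {
                    "train": indices[:n_train],
                    "val": indices[n_train : n_train + n_val],
                    "test": indices[n_train + n_val :],
                }
            )
            start += samples_per_client
        return partition


    data_indices: List[Dict[str, List[int]]] = make_dummy_partition(10, 100)
    # a client would index into its own entry, e.g. the test split of client 0
    client_0_test_indices = data_indices[0]["test"]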