Skip to content

Commit

Permalink
🌈 style(text): Rename some variables
Browse files Browse the repository at this point in the history
Rename test_gap to test_interval

Rename test_flag in client classes to testing

Format some codes by the new black
  • Loading branch information
KarhouTam committed Apr 2, 2024
1 parent e613cda commit e28f200
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 20 deletions.
8 changes: 5 additions & 3 deletions src/client/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def __init__(
except:
raise FileNotFoundError(f"Please partition {args.dataset} first.")

self.data_indices: List[List[int]] = partition["data_indices"]
# [0: {"train": [...], "val": [...], "test": [...]}, ...]
self.data_indices: List[Dict[str, List[int]]] = partition["data_indices"]

# --------- you can define your custom data preprocessing strategy here ------------
test_data_transform = transforms.Compose(
Expand Down Expand Up @@ -311,7 +312,7 @@ def test(
Returns:
Dict[str, Dict[str, Metrics]]: the evalutaion metrics stats.
"""
self.test_flag = True
self.testing = True
self.client_id = client_id
self.load_dataset()
self.set_parameters(new_parameters)
Expand All @@ -327,7 +328,8 @@ def test(
self.finetune()
results["after"] = self.evaluate()
self.model.load_state_dict(frz_params_dict)
self.test_flag = False

self.testing = False
return results

def finetune(self):
Expand Down
4 changes: 2 additions & 2 deletions src/server/adcol.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ def train_one_round(self):

def train_and_test_discriminator(self):
self.generate_client_index()
if (self.current_epoch + 1) % self.args.test_gap == 0:
if (self.current_epoch + 1) % self.args.test_interval == 0:
acc_before = self.test_discriminator()

self.train_discriminator()

if (self.current_epoch + 1) % self.args.test_gap == 0:
if (self.current_epoch + 1) % self.args.test_interval == 0:
acc_after = self.test_discriminator()
if (self.current_epoch + 1) % self.args.verbose_gap == 0:
self.logger.log(
Expand Down
2 changes: 1 addition & 1 deletion src/server/fedap.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def train(self):
if (E + 1) % self.args.verbose_gap == 0:
self.logger.log(" " * 30, f"TRAINING EPOCH: {E + 1}", " " * 30)

if (E + 1) % self.args.test_gap == 0:
if (E + 1) % self.args.test_interval == 0:
self.test()

self.selected_clients = self.client_sample_stream[E]
Expand Down
22 changes: 10 additions & 12 deletions src/server/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_fedavg_argparser() -> ArgumentParser:
parser.add_argument("-ge", "--global_epoch", type=int, default=100)
parser.add_argument("-le", "--local_epoch", type=int, default=5)
parser.add_argument("-fe", "--finetune_epoch", type=int, default=0)
parser.add_argument("-tg", "--test_gap", type=int, default=100)
parser.add_argument("-ti", "--test_interval", type=int, default=100)
parser.add_argument("--eval_test", type=int, default=1)
parser.add_argument("--eval_val", type=int, default=0)
parser.add_argument("--eval_train", type=int, default=0)
Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(
self.epoch_test = [
epoch
for epoch in range(0, self.args.global_epoch)
if (epoch + 1) % self.args.test_gap == 0
if (epoch + 1) % self.args.test_interval == 0
]
# For controlling behaviors of some specific methods while testing (not used by all methods)
self.test_flag = False
Expand Down Expand Up @@ -237,7 +237,7 @@ def train(self):
if (E + 1) % self.args.verbose_gap == 0:
self.logger.log("-" * 26, f"TRAINING EPOCH: {E + 1}", "-" * 26)

if (E + 1) % self.args.test_gap == 0:
if (E + 1) % self.args.test_interval == 0:
self.test()

self.selected_clients = self.client_sample_stream[E]
Expand All @@ -259,15 +259,13 @@ def train_one_round(self):
weight_cache = []
for client_id in self.selected_clients:
client_local_params = self.generate_client_params(client_id)
(
delta,
weight,
self.client_metrics[client_id][self.current_epoch],
) = self.trainer.train(
client_id=client_id,
local_epoch=self.clients_local_epoch[client_id],
new_parameters=client_local_params,
verbose=((self.current_epoch + 1) % self.args.verbose_gap) == 0,
(delta, weight, self.client_metrics[client_id][self.current_epoch]) = (
self.trainer.train(
client_id=client_id,
local_epoch=self.clients_local_epoch[client_id],
new_parameters=client_local_params,
verbose=((self.current_epoch + 1) % self.args.verbose_gap) == 0,
)
)

delta_cache.append(delta)
Expand Down
2 changes: 1 addition & 1 deletion src/server/metafed.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def train(self):
if (E + 1) % self.args.verbose_gap == 0:
self.logger.log("-" * 26, f"TRAINING EPOCH: {E + 1}", "-" * 26)

if (E + 1) % self.args.test_gap == 0:
if (E + 1) % self.args.test_interval == 0:
self.test()

client_params_cache = []
Expand Down
2 changes: 1 addition & 1 deletion src/server/pfedsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def train(self):
if (E + 1) % self.args.verbose_gap == 0:
self.logger.log(" " * 30, f"TRAINING EPOCH: {E + 1}", " " * 30)

if (E + 1) % self.args.test_gap == 0:
if (E + 1) % self.args.test_interval == 0:
self.test()

self.selected_clients = self.client_sample_stream[E]
Expand Down

0 comments on commit e28f200

Please sign in to comment.