Skip to content

Commit

Permalink
✨ feat(Args): Relax the request towards config file
Browse files Browse the repository at this point in the history
Also modify some default values in DEFAULT_COMMON_ARGS

Fix some links in README.md
  • Loading branch information
KarhouTam committed Apr 5, 2024
1 parent 89dc572 commit e132c8a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 40 deletions.
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ ALL classes of methods are inherited from `FedAvgServer` and `FedAvgClient`. If

### Step 1. Generate FL Dataset
```shell
# partition the CIFAR-10 according to Dir(0.1) for 100 clients
# Partition the CIFAR-10 according to Dir(0.1) for 100 clients
python generate_data.py -d cifar10 -a 0.1 -cn 100
```
About methods of generating federated dastaset, go check [`data/README.md`](data/#readme) for full details.
Expand All @@ -173,8 +173,8 @@ About methods of generating federated dastaset, go check [`data/README.md`](data
❗ Method name should be identical to the `.py` file name in `src/server`.

```
# run FedAvg on CIFAR-10 with default settings.
python main.py fedavg config/your_config.yml
# Run FedAvg with default settings.
python main.py fedavg
```


Expand All @@ -195,19 +195,21 @@ def get_fedprox_args(args_list=None) -> Namespace:
and you set
```yaml
# your_config.yml
...
fedprox:
mu: 0.1
```
in your config file. If you run
```shell
python main.py fedprox your_config.yml --mu 10 # fedprox.mu = 10
python main.py fedprox # fedprox.mu = 1
python main.py fedprox your_config.yml # fedprox.mu = 0.01
python main.py fedprox your_config.yml --mu 10 # fedprox.mu = 10
```


### Monitor 📈
1. Run `python -m visdom.server` on terminal.
2. Set `visible` as `True`.
2. Set `visible` as `true`.
3. Go check `localhost:8097` on your browser.

## Common Arguments 🔧
Expand Down Expand Up @@ -306,15 +308,15 @@ Medical Image Datasets

### New FL Method

- Inherit your method class from `FedAvgServer` and `FedAvgClient`, which are the base classes of FL-bench and ALL other method classes are inherited from them.
- Inherit your method class from [`FedAvgServer`](src/server/fedavg.py) and [`FedAvgClient`](src/client/fedavg.py), which are the base classes of FL-bench and ALL other method classes are inherited from them.

- Override the `train_one_round()` in `FedAvgServer` for your custom server-side process.
- Override the `train_one_round()` in [`FedAvgServer`](src/server/fedavg.py) for your custom server-side process.

- Override the `fit()` or `train()` in `FedAvgClient()` for your custom client-side process.
- Override the `fit()` or `train()` in [`FedAvgClient`](src/client/fedavg.py) for your custom client-side process.

### New Dataset

- Inherit your dataset class from [`BaseDataset` in `data/utils/datasets.py`](data/utils/dataset.py) and add your class in dict `DATASETS`.
- Inherit your dataset class from `BaseDataset` in [`data/utils/datasets.py`](data/utils/datasets.py) and add your class in dict `DATASETS`.

### New Model

Expand Down
18 changes: 8 additions & 10 deletions config/template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ common:
straggler_min_local_epoch: 0
external_model_params_file: null
optimizer:
name: sgd # [sgd, adam, adamw, rmsprop, adagrad]
name: "sgd" # [sgd, adam, adamw, rmsprop, adagrad]
lr: 0.01
dampening: 0 # [sgd]
dampening: 0 # SGD
weight_decay: 0
momentum: 0 # SGD, RMSprop
alpha: 0.99 # RMSprop
nesterov: false # SGD
betas: [0.9, 0.999] # Adam, AdamW
momentum: 0 # SGD, RMSprop,
alpha: 0.99 # RMSprop,
nesterov: false # SGD,
betas: [0.9, 0.999] # Adam, AdamW,
amsgrad: false # Adam, AdamW

eval_test: true
eval_val: false
Expand All @@ -35,17 +36,14 @@ common:
save_metrics: true
viz_win_name: null
check_convergence: true


# You can set specific arguments for FL methods also
# FL-bench uses FL method arguments by args.<method>.<arg>
# e.g.,
# fedprox:
# mu: 0.01
#
#
# pfedsim:
# warmup_round: 0.7
# ...


# NOTE: For those unmentioned arguments, the default values are set in `get_<method>_args()` in `src/server/<method>.py`
15 changes: 10 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@


if __name__ == "__main__":
if len(sys.argv) < 3:
raise ValueError(
"No <method> or <config_file>. Run like `python main.py <method> <config_file_relative_path> [cli_method_args ...]`,\n e.g., python main.py fedavg config/template.yml`"
if len(sys.argv) < 2:
raise RuntimeError(
"No method is specified. Run like `python main.py <method> [config_file_relative_path] [cli_method_args ...]`,\n e.g., python main.py fedavg config/template.yml`"
)

method_name = sys.argv[1]
config_file_path = sys.argv[2]
cli_method_args = sys.argv[3:]

config_file_path = None
cli_method_args = []
if len(sys.argv) > 2:
config_file_path = sys.argv[2]
cli_method_args = sys.argv[3:]

try:
method_module = importlib.import_module(method_name)
except:
Expand Down
9 changes: 5 additions & 4 deletions src/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@
"seed": 42,
"model": "lenet5",
"join_ratio": 0.1,
"global_epoch": 10,
"local_epoch": 10,
"global_epoch": 100,
"local_epoch": 5,
"finetune_epoch": 0,
"batch_size": 32,
"test_interval": 10,
"test_interval": 100,
"straggler_ratio": 0,
"straggler_min_local_epoch": 0,
"external_model_params_file": None,
"optimizer": {
"name": "sgd", # [sgd, adam, adamw, rmsprop, adagrad]
"lr": 0.01,
"dampening": 0, # [sgd],
"dampening": 0, # SGD,
"weight_decay": 0,
"momentum": 0, # SGD, RMSprop,
"alpha": 0.99, # RMSprop,
"nesterov": False, # SGD,
"betas": [0.9, 0.999], # Adam, AdamW,
"amsgrad": False, # Adam, AdamW
},
"eval_test": True,
"eval_val": False,
Expand Down
27 changes: 15 additions & 12 deletions src/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,21 @@ def parse_args(
Returns:
Namespace: The merged arg namespace.
"""
try:
with open(Path(config_file_path).absolute()) as f:
config_file_args = yaml.safe_load(f)
final_args = NestedNamespace(config_file_args)
except:
raise FileNotFoundError(
f"Unrecongnized config file path: {Path(config_file_path).absolute()}."
)

common_args = DEFAULT_COMMON_ARGS
common_args.update(config_file_args["common"])
final_args.__setattr__("common", NestedNamespace(common_args))
final_args = NestedNamespace({"common": DEFAULT_COMMON_ARGS})
config_file_args = {}
if config_file_path is not None:
try:
with open(Path(config_file_path).absolute()) as f:
config_file_args = yaml.safe_load(f)
final_args = NestedNamespace(config_file_args)

common_args = DEFAULT_COMMON_ARGS
common_args.update(config_file_args["common"])
final_args.__setattr__("common", NestedNamespace(common_args))
except:
Warning(
f"Unrecongnized config file path: {Path(config_file_path).absolute()}. All common arguments are rolled back to their defaults."
)

if get_method_args_func is not None:
default_method_args = get_method_args_func([]).__dict__
Expand Down

0 comments on commit e132c8a

Please sign in to comment.