Handy PyTorch implementation of Federated Learning (for your painless research)

vaseline555/Federated-Learning-in-PyTorch

Federated Learning in PyTorch

Implementations of various Federated Learning (FL) algorithms in PyTorch, especially for research purposes.

Implementation Details

Datasets

  • Supports all image classification datasets in torchvision.datasets.
  • Supports all text classification datasets in torchtext.datasets.
  • Supports all datasets in the LEAF benchmark (NO need to prepare raw data manually).
  • Supports additional image classification datasets (TinyImageNet, CINIC10).
  • Supports additional text classification datasets (BeerReviews).
  • Supports tabular datasets (Heart, Adult, Cover).
  • Supports a temporal dataset (GLEAM).
  • NOTE: there is no need to search for raw dataset files; a dataset is downloaded automatically to the designated path when you pass its name!

Statistical Heterogeneity Simulations

  • IID (i.e., statistical homogeneity)
  • Unbalanced (i.e., heterogeneity in per-client sample counts)
  • Pathological Non-IID (McMahan et al., 2016)
  • Dirichlet distribution-based Non-IID (Hsu et al., 2019)
  • Pre-defined (for datasets with a natural semantic separation, including the LEAF benchmark (Caldas et al., 2018))
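
To make the two synthetic non-IID settings above concrete, here is a minimal NumPy sketch. It is not taken from this repository: the function names `dirichlet_partition` and `pathological_partition` and all of their parameters are hypothetical. The Dirichlet version uses a common per-class variant (each class is split across clients with Dirichlet-drawn proportions), which may differ in detail from the sampling procedure in Hsu et al. (2019). The pathological version follows the sort-by-label-and-deal-shards idea of McMahan et al. (2016).

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients, class by class.

    For every class, the share each client receives is drawn from
    Dirichlet(alpha, ..., alpha). A small alpha gives skewed label
    distributions; a large alpha approaches an IID split.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        # Shuffle the indices of the samples belonging to this class.
        idx = rng.permutation(np.flatnonzero(labels == cls))
        # Proportion of this class assigned to each client.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn the proportions into cut points inside the shuffled indices.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(part.tolist())
    return client_indices


def pathological_partition(labels, num_clients, shards_per_client=2, seed=0):
    """Sort samples by label, cut them into shards, deal shards to clients.

    Each client ends up with samples from only a few classes.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    order = np.argsort(labels, kind="stable")  # indices sorted by label
    shards = np.array_split(order, num_clients * shards_per_client)
    shard_ids = rng.permutation(len(shards))  # random shard-to-client dealing
    partition = []
    for c in range(num_clients):
        own = shard_ids[c * shards_per_client:(c + 1) * shards_per_client]
        partition.append(np.concatenate([shards[s] for s in own]).tolist())
    return partition
```

Both functions return one list of sample indices per client, so the result can be passed, for example, to `torch.utils.data.Subset` to build per-client datasets.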

Models

Algorithms

Evaluation schemes

  • local: evaluate the FL algorithm on the holdout sets of (some or all) clients NOT participating in the current round (i.e., evaluation in the personalized federated learning setting).
  • global: evaluate the FL algorithm on a global holdout set located at the server (ONLY available if the raw dataset provides a pre-defined validation/test set).
  • both: evaluate the FL algorithm using both the local and global schemes.
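
The three schemes differ only in which holdout sets are used. The following plain-Python sketch illustrates that selection logic under stated assumptions; `select_eval_sets` and its arguments are hypothetical names and do not reflect the repository's actual code.

```python
def select_eval_sets(eval_type, client_holdouts, server_holdout, participating_ids):
    """Pick the holdout sets used in one evaluation round.

    eval_type         : "local", "global", or "both"
    client_holdouts   : dict mapping client id -> that client's holdout set
    server_holdout    : global holdout set at the server, or None if the
                        raw dataset has no pre-defined validation/test set
    participating_ids : ids of the clients that trained in the current round
    """
    if eval_type not in {"local", "global", "both"}:
        raise ValueError(f"unknown evaluation scheme: {eval_type!r}")

    selected = {}
    if eval_type in {"local", "both"}:
        # Local scheme: only clients that did NOT participate this round.
        selected["local"] = {
            cid: holdout
            for cid, holdout in client_holdouts.items()
            if cid not in participating_ids
        }
    if eval_type in {"global", "both"}:
        # Global scheme: requires a holdout set at the server.
        if server_holdout is None:
            raise ValueError("global evaluation needs a server-side holdout set")
        selected["global"] = server_holdout
    return selected
```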

Metrics

  • Top-1 Accuracy, Top-5 Accuracy, Precision, Recall, F1
  • Area under ROC, Area under PRC, Youden's J
  • Seq2Seq Accuracy
  • MSE, RMSE, MAE, MAPE
  • $R^2$, $D^2$
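
For reference, two of the listed metrics can be written in a few lines of NumPy. This is a sketch of the standard definitions only: the function names are hypothetical, and the repository may compute these metrics differently (for example, Youden's J is computed here at one fixed threshold rather than maximized over thresholds).

```python
import numpy as np


def top_k_accuracy(logits, targets, k=1):
    """Fraction of samples whose true class is among the k highest scores."""
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    # Indices of the k largest scores in each row.
    top_k = np.argsort(-logits, axis=1)[:, :k]
    hits = (top_k == targets[:, None]).any(axis=1)
    return float(hits.mean())


def youden_j(scores, targets, threshold=0.5):
    """Youden's J = sensitivity + specificity - 1 at a fixed threshold.

    Assumes binary targets in {0, 1} with at least one positive and one
    negative sample; otherwise a rate below would divide by zero.
    """
    scores = np.asarray(scores)
    positive = np.asarray(targets) == 1
    predicted = scores >= threshold
    sensitivity = (predicted & positive).sum() / positive.sum()
    specificity = (~predicted & ~positive).sum() / (~positive).sum()
    return float(sensitivity + specificity - 1.0)
```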

Requirements

  • See requirements.txt. (I recommend building an isolated environment for this project, e.g., with Docker or conda.)
  • When you install torchtext, please check the version compatibility with torch. (See official repository)
  • Also, please install all torch-related packages with a single command, as in the official guide (See official installation guide); e.g., conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 torchtext==0.13.0 cudatoolkit=11.6 -c pytorch -c conda-forge
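
After installation, one quick way to see which versions ended up in the environment is to query the package metadata with the standard library. This is only a convenience sketch and not part of the repository; `installed_versions` is a hypothetical helper, and the reported versions still need to be compared against the official compatibility tables by hand.

```python
from importlib import metadata


def installed_versions(packages=("torch", "torchvision", "torchaudio", "torchtext")):
    """Return {package name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```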

Configurations

  • See python3 main.py -h.

Example Commands

  • See the shell scripts provided in the commands directory.

TODO

  • Support more models, especially lightweight ones for the cross-device FL setting (e.g., EdgeNeXt).
  • Support more structured datasets, including temporal and tabular data, along with datasets suitable for the cross-silo FL setting (e.g., MedMNIST).
  • Add other popular FL algorithms, including personalized FL algorithms (e.g., SuPerFed).
  • Attach benchmark results of sample commands.

Contact

Should you have any feedback, please open a thread in the Issues tab. Thank you :)