From 7b874fdb0e0221d3d7fb68fb599d162ec254fbb3 Mon Sep 17 00:00:00 2001 From: Zed <928255708@qq.com> Date: Fri, 23 Sep 2022 15:15:14 +0800 Subject: [PATCH] [revision] add tutorial --- fedlab/core/standalone.py | 2 +- requirements.txt | 1 + tutorial.ipynb | 142 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 137 insertions(+), 8 deletions(-) diff --git a/fedlab/core/standalone.py b/fedlab/core/standalone.py index 7040daa5..04aee375 100644 --- a/fedlab/core/standalone.py +++ b/fedlab/core/standalone.py @@ -13,7 +13,7 @@ # limitations under the License. -class StandalonePipelineExample(object): +class StandalonePipeline(object): def __init__(self, handler, trainer): """Perform standalone simulation process. diff --git a/requirements.txt b/requirements.txt index 2a3f7255..5aa0be62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pandas pynvml tqdm scikit-learn +munch torch>=1.7.1 torchvision>=0.8.2 diff --git a/tutorial.ipynb b/tutorial.ipynb index 62fc6e48..e63aaeb0 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -4,7 +4,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Tutorials for FedLab users\n" + "# Tutorial for FedLab users\n", + "\n", + "This is a comprehensive tutorial for users who would like to get to know FedLab. FedLab is built on top of [torch.distributed](https://pytorch.org/docs/stable/distributed.html) and provides the necessary modules for FL simulation, including communication, compression, model optimization, data partition, and other functional modules. FedLab users can build an FL simulation environment with custom modules, much like playing with LEGO bricks.\n", + "\n", + "In this tutorial, we will further describe the architecture of FedLab and its usage. To put it simply, we introduce FedLab by implementing the vanilla federated learning algorithm, FedAvg, in this notebook." ] }, { @@ -12,13 +16,24 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# configuration\n", + "from munch import Munch\n", + "args = Munch()\n", + "\n", + "args.total_client = 100\n", + "args.alpha = 0.5\n", + "args.seed = 42\n", + "args.preprocess = True\n", + "\n", + "# local training hyperparameters used by trainer.setup_optim below (example values)\n", + "args.epochs = 5\n", + "args.batch_size = 128\n", + "args.lr = 0.1" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Prepare your dataset" + "## Prepare your dataset\n", + "\n", + "FedLab provides the necessary modules for users to partition their datasets. Additionally, various dataset partition implementations for federated learning are available at [fedlab/dataset](https://github.com/SMILELab-FL/FedLab/tree/master/fedlab/dataset)."
] }, { @@ -26,15 +41,121 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# We provide an example usage of the partitioned MNIST dataset\n", + "# Download the raw MNIST dataset and partition it according to the given configuration\n", + "\n", + "from torchvision import transforms\n", + "from fedlab.dataset.partitioned_mnist import PartitionedMNIST\n", + "\n", + "fed_mnist = PartitionedMNIST(root=\"./tests/data/mnist/\",\n", + "                             path=\"./tests/data/mnist/fedmnist/\",\n", + "                             num_clients=args.total_client,\n", + "                             partition=\"noniid-labeldir\",\n", + "                             dir_alpha=args.alpha,\n", + "                             seed=args.seed,\n", + "                             preprocess=args.preprocess,\n", + "                             download=True,\n", + "                             verbose=True,\n", + "                             transform=transforms.Compose(\n", + "                                 [transforms.ToPILImage(), transforms.ToTensor()]))\n", + "\n", + "\n", + "dataset = fed_mnist.get_dataset(0) # get the 0-th client's dataset\n", + "dataloader = fed_mnist.get_dataloader(0, batch_size=128) # get the 0-th client's data loader with batch size 128" + ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Client local training\n", + "\n", + "The client training procedure is implemented by the ClientTrainer class in FedLab. FedLab provides a built-in FedAvg implementation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# client\n", + "from fedlab.contrib.clients import SGDSerialTrainer, SGDClientTrainer\n", + "from fedlab.models.mlp import MLP\n", + "\n", + "model = MLP(784, 10)\n", + "\n", + "trainer = SGDSerialTrainer(model, args.total_client, cuda=True) # serial trainer\n", + "# trainer = SGDClientTrainer(model, cuda=True) # single trainer\n", + "\n", + "trainer.setup_dataset(fed_mnist)\n", + "trainer.setup_optim(args.epochs, args.batch_size, args.lr)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Server global aggregation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# server\n", + "from fedlab.contrib.servers.server import SyncServerHandler\n", + "\n", + "args.com_round = 10\n", + "args.sample_ratio = 0.1\n", + "\n", + "handler = SyncServerHandler(model=model, global_round=args.com_round, sample_ratio=args.sample_ratio)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Choose simulation mode\n", + "\n", + "We provide three basic simulation modes in FedLab, depending on the needs of users.\n", + "\n", + "1. Choose Standalone mode to run the simulation with the lowest resource allocation.\n", + "2. Choose Cross-process mode to run the simulation on multiple machines or multiple GPUs for faster computation.\n", + "3. Choose Hierarchical mode to run the simulation across computer clusters." + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Define clients" + "### Standalone\n", + "\n", + "We provide an example pipeline implementation. 
Please see [URL](https://github.com/SMILELab-FL/FedLab/blob/master/fedlab/core/standalone.py)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fedlab.core.standalone import StandalonePipeline\n", + "\n", + "standalone = StandalonePipeline(handler=handler, trainer=trainer)\n", + "standalone.main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -46,7 +167,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Define server" + "### Cross-process" ] }, { @@ -56,11 +177,18 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Choose simulation mode" + "### Hierarchical" ] }, {
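+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "Whichever simulation mode you choose, the server handler and the client trainer exchange model packages in the same way in every communication round. The next cell is only a rough, hand-written sketch of that loop with assumed attribute names (sample_clients, downlink_package, local_process, uplink_package, load); it is not necessarily the exact FedLab API. See fedlab/core/standalone.py for the real standalone implementation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Rough sketch only (assumed attribute names, not necessarily the FedLab API):\n", + "# one communication round of package exchange between a handler and a trainer.\n", + "def toy_round(handler, trainer):\n", + "    clients = handler.sample_clients() # server samples the participating clients\n", + "    payload = handler.downlink_package # server packs the current global model\n", + "    trainer.local_process(payload, clients) # sampled clients train locally\n", + "    for package in trainer.uplink_package: # server aggregates the client uploads\n", + "        handler.load(package)" + ] + }, + {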