Skip to content

Commit

Permalink
Merge pull request #1152 from lemonviv/dev-postgresql
Browse files Browse the repository at this point in the history
Update the HFL example and README
  • Loading branch information
lzjpaul authored Mar 23, 2024
2 parents 6085ea9 + 2ab4ed6 commit c4496ed
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/hfl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ This example uses the Bank dataset and an MLP model in FL.

## Preparation

Go to the Conda environment that contains the Singa library, and run
Go to the Conda environment that contains the Singa library, and install the required libraries.

```bash
pip install -r requirements.txt
Expand All @@ -41,23 +41,23 @@ Download the bank dataset and split it into 3 partitions.
# 3. run the following command which:
# (1) splits the dataset into N subsets
# (2) splits each subsets into train set and test set (8:2)
python -m bank N
python -m bank 3
```

## Run the example

Run the server first (set the number of epochs to 3)
Run the server first (set the maximum number of epochs to 3 by the "-m" parameter)

```bash
python -m src.server -m 3 --num_clients 3
```

Then, start 3 clients in different terminal
Then, start 3 clients in different terminals (similarly set the maximum number of epochs to 3)

```bash
python -m src.client --model mlp --data bank -m 3 -i 0 -d non-iid
python -m src.client --model mlp --data bank -m 3 -i 1 -d non-iid
python -m src.client --model mlp --data bank -m 3 -i 2 -d non-iid
```

Finally, the server and clients finish the FL training.
Finally, the server and clients finish the FL training.
2 changes: 2 additions & 0 deletions examples/hfl/src/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
np_dtype = {"float16": np.float16, "float32": np.float32}
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}


class Client:
"""Client sends and receives protobuf messages.
Expand All @@ -63,6 +64,7 @@ def __init__(
Args:
global_rank (int, optional): The rank in training process. Defaults to 0.
Provided by the '-i' parameter (device_id) in the running script.
host (str, optional): Host ip address. Defaults to '127.0.0.1'.
port (str, optional): Port. Defaults to 1234.
"""
Expand Down
1 change: 1 addition & 0 deletions examples/hfl/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __start_rank_pairing(self) -> None:
"""Start pair each client to a global rank"""
for _ in range(self.num_clients):
conn, addr = self.sock.accept()
# rank is the global device_id when initializing the client
rank = utils.receive_int(conn)
self.conns[rank] = conn
self.addrs[rank] = addr
Expand Down

0 comments on commit c4496ed

Please sign in to comment.