Skip to content

Commit

Permalink
Merge pull request #333 from SMILELab-FL/fix-evaluate_demo-siqi
Browse files Browse the repository at this point in the history
fix standalone demo evaluate problem
  • Loading branch information
AgentDS authored Sep 19, 2023
2 parents ee91d3f + 3cf9c40 commit 17b6668
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/standalone-mnist/launch_eg.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1
python standalone.py --total_clients 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1
18 changes: 13 additions & 5 deletions examples/standalone-mnist/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# configuration
parser = argparse.ArgumentParser(description="Standalone training example")
parser.add_argument("--total_client", type=int, default=100)
parser.add_argument("--total_clients", type=int, default=100)
parser.add_argument("--com_round", type=int)

parser.add_argument("--sample_ratio", type=float)
Expand All @@ -34,18 +34,26 @@

args = parser.parse_args()

model =MLP(784, 10)
model = MLP(784, 10)

# server
handler = SyncServerHandler(model, args.com_round, args.total_clients, args.sample_ratio)
handler = SyncServerHandler(
model, args.com_round, args.total_clients, args.sample_ratio
)

# client
trainer = SGDSerialClientTrainer(model, args.total_client, cuda=True)
dataset = PathologicalMNIST(root='../../datasets/mnist/', path="../../datasets/mnist/", num_clients=args.total_client)
trainer = SGDSerialClientTrainer(model, args.total_clients, cuda=True)
dataset = PathologicalMNIST(
root="../../datasets/mnist/",
path="../../datasets/mnist/",
num_clients=args.total_clients,
)
dataset.preprocess()

trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

handler.setup_dataset(dataset)
# main
pipeline = StandalonePipeline(handler, trainer)
pipeline.main()

0 comments on commit 17b6668

Please sign in to comment.