diff --git a/examples/standalone-mnist/launch_eg.sh b/examples/standalone-mnist/launch_eg.sh index 4630bafd..157ed961 100644 --- a/examples/standalone-mnist/launch_eg.sh +++ b/examples/standalone-mnist/launch_eg.sh @@ -1,3 +1,3 @@ #!/bin/bash -python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1 +python standalone.py --total_clients 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1 diff --git a/examples/standalone-mnist/standalone.py b/examples/standalone-mnist/standalone.py index 0c0bb5fc..719c5f45 100644 --- a/examples/standalone-mnist/standalone.py +++ b/examples/standalone-mnist/standalone.py @@ -24,7 +24,7 @@ # configuration parser = argparse.ArgumentParser(description="Standalone training example") -parser.add_argument("--total_client", type=int, default=100) +parser.add_argument("--total_clients", type=int, default=100) parser.add_argument("--com_round", type=int) parser.add_argument("--sample_ratio", type=float) @@ -34,18 +34,26 @@ args = parser.parse_args() -model =MLP(784, 10) +model = MLP(784, 10) # server -handler = SyncServerHandler(model, args.com_round, args.total_clients, args.sample_ratio) +handler = SyncServerHandler( + model, args.com_round, args.total_clients, args.sample_ratio +) # client -trainer = SGDSerialClientTrainer(model, args.total_client, cuda=True) -dataset = PathologicalMNIST(root='../../datasets/mnist/', path="../../datasets/mnist/", num_clients=args.total_client) +trainer = SGDSerialClientTrainer(model, args.total_clients, cuda=True) +dataset = PathologicalMNIST( + root="../../datasets/mnist/", + path="../../datasets/mnist/", + num_clients=args.total_clients, +) dataset.preprocess() + trainer.setup_dataset(dataset) trainer.setup_optim(args.epochs, args.batch_size, args.lr) +handler.setup_dataset(dataset) # main pipeline = StandalonePipeline(handler, trainer) pipeline.main()