Skip to content

Commit

Permalink
Merge pull request #1102 from NLGithubWP/add_train_mpi
Browse files Browse the repository at this point in the history
Add the training script for models using MPI
  • Loading branch information
lzjpaul authored Sep 23, 2023
2 parents 931a312 + a7f59a1 commit 41aa437
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions examples/model_selection_psql/ms_mlp/train_mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#


from singa import singa_wrap as singa
from singa import opt
from singa import tensor
import argparse
import train_cnn

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}

if __name__ == '__main__':
# Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
parser = argparse.ArgumentParser(
description='Training using the autograd and graph.')
parser.add_argument('model',
choices=['cnn', 'resnet', 'xceptionnet', 'mlp'],
default='cnn')
parser.add_argument('data', choices=['mnist', 'cifar10', 'cifar100'], default='mnist')
parser.add_argument('-p',
choices=['float32', 'float16'],
default='float32',
dest='precision')
parser.add_argument('-m',
'--max-epoch',
default=10,
type=int,
help='maximum epochs',
dest='max_epoch')
parser.add_argument('-b',
'--batch-size',
default=64,
type=int,
help='batch size',
dest='batch_size')
parser.add_argument('-l',
'--learning-rate',
default=0.005,
type=float,
help='initial learning rate',
dest='lr')
parser.add_argument('-d',
'--dist-option',
default='plain',
choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'],
help='distibuted training options',
dest='dist_option') # currently partialUpdate support graph=False only
parser.add_argument('-s',
'--sparsification',
default='0.05',
type=float,
help='the sparsity parameter used for sparsification, between 0 to 1',
dest='spars')
parser.add_argument('-g',
'--disable-graph',
default='True',
action='store_false',
help='disable graph',
dest='graph')
parser.add_argument('-v',
'--log-verbosity',
default=0,
type=int,
help='logging verbosity',
dest='verbosity')

args = parser.parse_args()

sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision])
sgd = opt.DistOpt(sgd)

train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch,
args.batch_size, args.model, args.data, sgd, args.graph,
args.verbosity, args.dist_option, args.spars, args.precision)

0 comments on commit 41aa437

Please sign in to comment.