#!/usr/bin/env python3
# This example is inspired by https://www.tensorflow.org/guide/keras/train_and_evaluate
#
# KungFu requires users to make the following changes:
# 1. Wrap the original optimizer in a KungFu distributed optimizer. The
#    distributed optimizer defines how local gradients and model weights are
#    synchronized.
# 2. (Optional) Partition the training dataset, as distributed training
#    usually shards the data across workers.
# 3. (Optional) Scale the learning rate of your local optimizer by the
#    cluster size.
#
# Command to run this script:
# $ ./bin/kungfu-run -np 4 python3 examples/tf1_mnist_keras.py --n-epochs 10
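#
# Other synchronization strategies can be selected with the --kf-optimizer
# flag defined below (sync-sgd, async-sgd or sma), for example:
# $ ./bin/kungfu-run -np 4 python3 examples/tf1_mnist_keras.py --kf-optimizer sma --n-epochs 10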
import argparse
import tensorflow as tf
from kungfu.python import current_cluster_size, current_rank
from kungfu.tensorflow.initializer import BroadcastGlobalVariablesCallback


def load_dataset():
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # preprocess the MNIST dataset: flatten each 28x28 image into a
    # 784-dimensional vector and scale pixel values to [0, 1]
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')

    # split off the last 10,000 training examples as a validation set,
    # leaving 50,000 examples for training
dataset = dict()
dataset['x_val'] = x_train[-10000:]
dataset['y_val'] = y_train[-10000:]
dataset['x_train'] = x_train[:-10000]
dataset['y_train'] = y_train[:-10000]
dataset['x_test'] = x_test
dataset['y_test'] = y_test
return dataset


def build_optimizer(name, n_shards=1):
learning_rate = 0.1
# Scale learning rate according to the level of data parallelism
optimizer = tf.train.GradientDescentOptimizer(learning_rate * n_shards)
# KUNGFU: Wrap the TensorFlow optimizer with KungFu distributed optimizers.
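    # The available strategies are, roughly:
    # - SynchronousSGDOptimizer: all-reduce the local gradients across all
    #   workers before every weight update (synchronous SGD).
    # - PairAveragingOptimizer: asynchronously average model weights with a
    #   randomly selected peer.
    # - SynchronousAveragingOptimizer: synchronously average model weights
    #   after each local update (SMA).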
if name == 'sync-sgd':
from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer
return SynchronousSGDOptimizer(optimizer)
elif name == 'async-sgd':
from kungfu.tensorflow.optimizers import PairAveragingOptimizer
return PairAveragingOptimizer(optimizer, fuse_requests=True)
elif name == 'sma':
from kungfu.tensorflow.optimizers import SynchronousAveragingOptimizer
return SynchronousAveragingOptimizer(optimizer)
else:
raise RuntimeError('unknown optimizer: %s' % name)


def build_model(optimizer):
num_classes = 10
# create a model with keras
model = tf.keras.Sequential()
    # add two hidden layers
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
    # add an output layer with one node per class and softmax activation
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
# compile the model
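    # sparse_categorical_crossentropy expects integer-valued class labels
    # (rather than one-hot vectors), which matches the label format returned
    # by load_dataset.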
model.compile(optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model


def train_model(model, dataset, n_epochs=1, batch_size=5000):
    n_shards = current_cluster_size()
    shard_id = current_rank()
    train_data_size = len(dataset['x_train'])

    # calculate the offset of the data shard for this KungFu node
    shard_size = train_data_size // n_shards
    offset = shard_size * shard_id

    # extract the data shard for this KungFu node
    x = dataset['x_train'][offset:offset + shard_size]
    y = dataset['y_train'][offset:offset + shard_size]
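    # For example, with 50,000 training examples and 4 workers, each shard
    # holds 12,500 examples and worker i trains on rows
    # [12500 * i, 12500 * (i + 1)).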
# train the model
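    # Note that batch_size is the per-worker batch size: with n_shards workers
    # the effective global batch size is batch_size * n_shards, which is why
    # build_optimizer can scale the learning rate by n_shards.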
model.fit(x,
y,
batch_size=batch_size,
epochs=n_epochs,
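              # KUNGFU: this callback broadcasts the initial model weights of
              # the root worker to all other workers, so that every worker
              # starts training from the same state.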
callbacks=[BroadcastGlobalVariablesCallback()],
validation_data=(dataset['x_val'], dataset['y_val']),
verbose=2)


def test_model(model, dataset):
test_metrics = model.evaluate(dataset['x_test'], dataset['y_test'])
    # model.evaluate returns [loss, accuracy], so the accuracy is at index 1
    accuracy_index = 1
print('test accuracy: %f' % test_metrics[accuracy_index])


def parse_args():
parser = argparse.ArgumentParser(description='KungFu mnist example.')
parser.add_argument('--kf-optimizer',
type=str,
default='sync-sgd',
help='kungfu optimizer')
parser.add_argument('--n-epochs',
type=int,
default=1,
help='number of epochs')
parser.add_argument('--batch-size',
type=int,
default=50,
help='batch size')
return parser.parse_args()


def main():
    # parse arguments from the command line
    args = parse_args()
    # build the KungFu optimizer, scaling the learning rate by the cluster size
    optimizer = build_optimizer(args.kf_optimizer, n_shards=current_cluster_size())
    # build the TensorFlow model
    model = build_model(optimizer)
    # load the MNIST dataset
    dataset = load_dataset()
    # train the TensorFlow model
    train_model(model, dataset, args.n_epochs, args.batch_size)
    # test the performance of the TensorFlow model
    test_model(model, dataset)


if __name__ == '__main__':
main()